In [1]:
import numpy as np
from dataclasses import dataclass
from typing import Callable, Optional, Any, Dict

# Define termination messages globally
termination_messages = {
    0: "No convergence.",
    1: "The sufficient decrease condition and the directional derivative condition hold.",
    2: "Relative width of the interval of uncertainty is at most xtol.",
    3: "Number of calls to phi has reached maxfev.",
    4: "The step is at the lower bound alpha_min.",
    5: "The step is at the upper bound alpha_max.",
    6: "Rounding errors prevent further progress.",
    7: "Unable to create finite alpha."
}

def get_termination_message(info_code: int) -> str:
    """
    Retrieves the termination message based on the info code.

    Args:
        info_code (int): The termination code.

    Returns:
        str: The corresponding termination message.
    """
    return termination_messages.get(info_code, "Unknown termination reason.")

@dataclass
class Step:
    """
    Represents a step in the line search.

    Attributes:
        alpha (float): Step size.
        f (float): Function value at the step.
        d (float): Directional derivative at the step.
    """
    alpha: float
    f: float
    d: float


def armijo_ok_step(step0: Step, step: Step, c1: float) -> bool:
    """
    Checks the Armijo (sufficient decrease) condition.

    Args:
        step0 (Step): Initial step before the line search.
        step (Step): Current step being evaluated.
        c1 (float): Armijo condition constant.

    Returns:
        bool: True if the condition is satisfied, False otherwise.
    """
    return step.f <= step0.f + c1 * step0.d * step.alpha


def make_approx_armijo_ok_step(eps: float) -> Callable[[Step, Step, float], bool]:
    """
    Creates an approximate Armijo condition checker with tolerance.

    Args:
        eps (float): Tolerance for the approximation.

    Returns:
        Callable[[Step, Step, float], bool]: A function that checks the approximate Armijo condition.
    """
    def approx_armijo_ok_step(step0: Step, step: Step, c1: float) -> bool:
        return step.f <= step0.f + c1 * step0.d * step.alpha + eps
    return approx_armijo_ok_step


def curvature_ok_step(step0: Step, step: Step, c: float) -> bool:
    """
    Checks the curvature condition.

    Args:
        step0 (Step): Initial step before the line search.
        step (Step): Current step being evaluated.
        c (float): Curvature condition constant.

    Returns:
        bool: True if the condition is satisfied, False otherwise.
    """
    return abs(step.d) <= c * abs(step0.d)


def make_wolfe_ok_step_fn(strong_curvature: bool = True,
                          approx_armijo: bool = False,
                          eps: float = 1e-6) -> Callable[[Step, Step, float, float], bool]:
    """
    Creates a function to check Wolfe conditions.

    Args:
        strong_curvature (bool): If True, checks the strong Wolfe condition.
        approx_armijo (bool): If True, uses an approximate Armijo condition.
        eps (float): Tolerance for the approximate Armijo condition.

    Returns:
        Callable[[Step, Step, float, float], bool]: A function that checks the Wolfe conditions.
    """
    if approx_armijo:
        armijo_fn = make_approx_armijo_ok_step(eps)
    else:
        armijo_fn = armijo_ok_step

    def wolfe_ok_step_fn(step0: Step, step: Step, c1: float, c2: float) -> bool:
        """
        Checks the Wolfe conditions for a given step.

        Args:
            step0 (Step): Initial step before the line search.
            step (Step): Current step being evaluated.
            c1 (float): Armijo condition constant.
            c2 (float): Curvature condition constant.

        Returns:
            bool: True if both conditions are satisfied, False otherwise.
        """
        armijo = armijo_fn(step0, step, c1)
        if strong_curvature:
            # Strong Wolfe: d(alpha) >= c2 * d0
            curvature = step.d >= c2 * step0.d
        else:
            # Standard Wolfe: |d(alpha)| <= c2 * |d0|
            curvature = abs(step.d) <= c2 * abs(step0.d)
        return armijo and curvature

    return wolfe_ok_step_fn


def quadratic_interpolate(stx: float, fx: float, dx: float,
                         stp: float, fp: float) -> float:
    """
    Performs quadratic interpolation to estimate a new step.

    Args:
        stx (float): Step size at the best step so far.
        fx (float): Function value at stx.
        dx (float): Directional derivative at stx.
        stp (float): Current step size.
        fp (float): Function value at stp.

    Returns:
        float: Estimated step size.
    """
    denom = 2 * (fx - fp - dx * (stp - stx))
    if denom == 0:
        return (stx + stp) / 2.0
    alpha = stx - (dx * (stp - stx) ** 2) / denom
    return alpha


def quadratic_interpolate_g(stp: float, dp: float,
                           stx: float, dx: float) -> float:
    """
    Performs quadratic interpolation based on derivative information.

    Args:
        stp (float): Current step size.
        dp (float): Directional derivative at stp.
        stx (float): Step size at the best step so far.
        dx (float): Directional derivative at stx.

    Returns:
        float: Estimated step size.
    """
    denom = dp - dx
    if denom == 0:
        return (stx + stp) / 2.0
    alpha = stx - dx * (stp - stx) / denom
    return alpha


def cubic_interpolate(stx: float, fx: float, dx: float,
                     stp: float, fp: float, dp: float,
                     ignore_warnings: bool = True) -> float:
    """
    Performs cubic interpolation to estimate a new step.

    Args:
        stx (float): Step size at the best step so far.
        fx (float): Function value at stx.
        dx (float): Directional derivative at stx.
        stp (float): Current step size.
        fp (float): Function value at stp.
        dp (float): Directional derivative at stp.
        ignore_warnings (bool): If True, returns NaN instead of raising an error.

    Returns:
        float: Estimated step size, or NaN if interpolation fails.
    """
    h = stp - stx
    if h == 0:
        return (stx + stp) / 2.0

    a = (2 * (fx - fp) + h * (dx + dp)) / h**3
    b = (3 * (fp - fx) - h * (2 * dx + dp)) / h**2
    c = dx

    # Check if 'a' is too small to prevent division by zero
    if np.isclose(a, 0.0):
        if ignore_warnings:
            return np.nan
        else:
            raise ValueError("Cubic interpolation: coefficient 'a' is zero.")

    discrim = b**2 - 3 * a * c
    if discrim < 0:
        if ignore_warnings:
            return np.nan
        else:
            raise ValueError("Cubic interpolation: negative discriminant")
    sqrt_discrim = np.sqrt(discrim)
    x1 = (-b + sqrt_discrim) / (3 * a)
    x2 = (-b - sqrt_discrim) / (3 * a)
    candidates = [x for x in [x1, x2] if 0 < x < h]
    if not candidates:
        if ignore_warnings:
            return np.nan
        else:
            raise ValueError("Cubic interpolation: no valid roots")
    alpha_new = stx + min(candidates, key=lambda x: abs(x))
    return alpha_new


def ensure_min_alpha(alpha: float, x: float, y: float, minshrink: float = 0.001) -> float:
    """
    Ensures that the new alpha is not too close to the interval bounds.

    Args:
        alpha (float): Suggested step size.
        x (float): One end of the interval.
        y (float): Other end of the interval.
        minshrink (float): Minimum fraction of the interval range that alpha can increase by.

    Returns:
        float: Modified step size.
    """
    l = min(x, y)
    length = abs(x - y)
    return max(l + minshrink * length, alpha)


def modify_step(step: Step, dgtest: float) -> Step:
    """
    Modifies a step based on the sufficient decrease condition.

    Args:
        step (Step): Current step.
        dgtest (float): Product of the initial directional derivative and the sufficient decrease constant.

    Returns:
        Step: Modified step.
    """
    return Step(
        alpha=step.alpha,
        f=step.f - step.alpha * dgtest,
        d=step.d - dgtest
    )


def unmodify_step(stepm: Step, dgtest: float) -> Step:
    """
    Reverses the modification of a step.

    Args:
        stepm (Step): Modified step.
        dgtest (float): Product of the initial directional derivative and the sufficient decrease constant.

    Returns:
        Step: Unmodified step.
    """
    return Step(
        alpha=stepm.alpha,
        f=stepm.f + stepm.alpha * dgtest,
        d=stepm.d + dgtest
    )


def find_finite(phi: Callable[[float], Dict[str, float]],
                alpha: float, maxfev: int,
                min_alpha: Optional[float] = None) -> Dict[str, Any]:
    """
    Evaluates phi(alpha) and ensures that the function value and derivative are finite.
    If not, performs bisection to find a finite step size.

    Args:
        phi (Callable[[float], Dict[str, float]]): Objective function along the search direction.
        alpha (float): Current step size.
        maxfev (int): Maximum number of function evaluations allowed.
        min_alpha (Optional[float]): Minimum acceptable step size.

    Returns:
        Dict[str, Any]: Contains 'ok' (bool), 'step' (Step or None), and 'nfn' (int).
    """
    nfn = 0
    alpha_lower = min_alpha if min_alpha is not None else 0.0
    step = None

    while nfn < maxfev:
        try:
            result = phi(alpha)
            f = result['f']
            d = result['d']
            if np.isfinite(f) and np.isfinite(d):
                step = Step(alpha=alpha, f=f, d=d)
                return {'ok': True, 'step': step, 'nfn': nfn + 1}
        except (KeyError, TypeError):
            # Handle specific exceptions related to phi's output
            pass
        except Exception as e:
            # Optionally, re-raise unexpected exceptions
            raise e

        # If not finite, reduce alpha by half
        if alpha > alpha_lower:
            alpha = (alpha + alpha_lower) / 2.0
        else:
            break
        nfn += 1

    return {'ok': False, 'step': None, 'nfn': nfn}


def check_convergence(step0: Step, step: Step, brackt: bool, infoc: int,
                     stmin: float, stmax: float,
                     alpha_min: float, alpha_max: float,
                     c1: float, c2: float, nfev: int,
                     maxfev: int, xtol: float,
                     armijo_check_fn: Callable[[Step, Step, float], bool],
                     wolfe_ok_step_fn: Callable[[Step, Step, float, float], bool],
                     verbose: bool = False) -> int:
    """
    Checks whether the line search has converged based on various criteria.

    Args:
        step0 (Step): Initial step before the line search.
        step (Step): Current step being evaluated.
        brackt (bool): Whether the step has been bracketed.
        infoc (int): Return code from the last step size update.
        stmin (float): Smallest value of the step size interval.
        stmax (float): Largest value of the step size interval.
        alpha_min (float): Minimum acceptable step size.
        alpha_max (float): Maximum acceptable step size.
        c1 (float): Armijo condition constant.
        c2 (float): Curvature condition constant.
        nfev (int): Current number of function evaluations.
        maxfev (int): Maximum number of function evaluations allowed.
        xtol (float): Relative width tolerance.
        armijo_check_fn (Callable[[Step, Step, float], bool]): Function to check the Armijo condition.
        wolfe_ok_step_fn (Callable[[Step, Step, float, float], bool]): Function to check the Wolfe conditions.
        verbose (bool): If True, prints debug information.

    Returns:
        int: Integer code indicating the convergence state.
            0: No convergence.
            1: The sufficient decrease condition and the directional derivative condition hold.
            2: Relative width of the interval of uncertainty is at most xtol.
            3: Number of calls to phi has reached maxfev.
            4: The step is at the lower bound alpha_min.
            5: The step is at the upper bound alpha_max.
            6: Rounding errors prevent further progress.
            7: Unable to create finite alpha.
    """
    info = 0
    if (brackt and (step.alpha <= stmin or step.alpha >= stmax)) or (infoc == 0):
        if verbose:
            print(f"MT: Rounding errors prevent further progress: stmin = {stmin}, stmax = {stmax}")
        # Rounding errors prevent further progress
        info = 6

    if (step.alpha == alpha_max and armijo_check_fn(step0, step, c1) and
            not curvature_ok_step(step0, step, c2)):
        # Reached alpha_max
        info = 5
        if verbose:
            print("MT: Reached alpha max")

    if (step.alpha == alpha_min and
            (not armijo_check_fn(step0, step, c1) or
             curvature_ok_step(step0, step, c2))):
        # Reached alpha_min
        info = 4
        if verbose:
            print("MT: Reached alpha min")

    if nfev >= maxfev:
        # Maximum number of function evaluations reached
        info = 3
        if verbose:
            print("MT: Exceeded maximum number of function evaluations")

    if brackt and (stmax - stmin) <= xtol * stmax:
        # Interval width is below xtol
        info = 2
        if verbose:
            print(f"MT: Interval width is <= xtol: {xtol * stmax}")

    if wolfe_ok_step_fn(step0, step, c1, c2):
        # Success
        info = 1
        if verbose:
            print("Success! Step satisfies Wolfe conditions.")

    return info


def cstep(stepx: Step, stepy: Step, step: Step, brackt: bool,
          stpmin: float, stpmax: float,
          safeguard_cubic: bool = False,
          verbose: bool = False) -> Dict[str, Any]:
    """
    Updates the interval of uncertainty and computes a new trial step.

    Args:
        stepx (Step): One side of the updated step interval.
        stepy (Step): Other side of the updated step interval.
        step (Step): Current trial step.
        brackt (bool): Whether the step has been bracketed.
        stpmin (float): Minimum allowed step size.
        stpmax (float): Maximum allowed step size.
        safeguard_cubic (bool): If True, ensures the cubic step isn't too close to the bounds.
        verbose (bool): If True, prints debug information.

    Returns:
        Dict[str, Any]: Contains updated 'stepx', 'stepy', 'step', 'brackt', and 'info'.
    """
    stx = stepx.alpha
    fx = stepx.f
    dx = stepx.d

    sty = stepy.alpha
    fy = stepy.f
    dy = stepy.d

    stp = step.alpha
    fp = step.f
    dp = step.d

    delta = 0.66
    info = 0

    # Check the input parameters for errors
    if (brackt and (stp <= min(stx, sty) or stp >= max(stx, sty))) or \
            (dx * (stp - stx) >= 0.0) or (stpmax < stpmin):
        # Improper input parameters
        if verbose:
            print("cstep: Improper input parameters detected.")
        return {
            'stepx': stepx,
            'stepy': stepy,
            'step': step,
            'brackt': brackt,
            'info': info
        }

    # Determine if the derivatives have opposite sign
    sgnd = dp * np.sign(dx)  # Equivalent to dp * sign(dx)

    # First case: fp > fx
    if fp > fx:
        info = 1
        bound = True

        stpc = cubic_interpolate(stx, fx, dx, stp, fp, dp, ignore_warnings=True)
        stpq = quadratic_interpolate(stx, fx, dx, stp, fp)

        if np.isnan(stpc):
            stpf = stpq
            if verbose:
                print("cstep: Cubic interpolation failed, using quadratic interpolation.")
        else:
            if abs(stpc - stx) < abs(stpq - stx):
                if safeguard_cubic:
                    stpf = ensure_min_alpha(stpc, stx, sty)
                    if verbose:
                        print("cstep: Using safeguarded cubic interpolation.")
                else:
                    stpf = stpc
                    if verbose:
                        print("cstep: Using cubic interpolation.")
            else:
                stpf = stpc + (stpq - stpc) / 2.0
                if verbose:
                    print("cstep: Using combination of cubic and quadratic interpolation.")

        brackt = True

    # Second case: sgnd < 0
    elif sgnd < 0.0:
        info = 2
        bound = False

        stpc = cubic_interpolate(stx, fx, dx, stp, fp, dp, ignore_warnings=True)
        stpq = quadratic_interpolate_g(stp, dp, stx, dx)

        if np.isnan(stpc):
            stpf = stpq
            if verbose:
                print("cstep: Cubic interpolation failed, using quadratic interpolation based on derivatives.")
        else:
            if abs(stpc - stp) > abs(stpq - stp):
                if safeguard_cubic:
                    stpf = ensure_min_alpha(stpc, stx, sty)
                    if verbose:
                        print("cstep: Using safeguarded cubic interpolation.")
                else:
                    stpf = stpc
                    if verbose:
                        print("cstep: Using cubic interpolation.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation based on derivatives.")

        brackt = True

    # Third case: |dp| < |dx|
    elif abs(dp) < abs(dx):
        info = 3
        bound = True
        theta = 3.0 * (fx - fp) / (stp - stx) + dx + dp
        s = np.linalg.norm([theta, dx, dp], ord=np.inf)  # Infinity norm
        gamma_sq = (theta / s)**2 - (dx / s) * (dp / s)
        gamma = s * np.sqrt(max(0.0, gamma_sq))
        if stp > stx:
            gamma = -gamma
        p = (gamma - dp) + theta
        q = (gamma + (dx - dp)) + gamma
        if q == 0.0:
            r = 0.0
        else:
            r = p / q

        if r < 0.0 and gamma != 0.0:
            stpc = stp + r * (stx - stp)
        elif stp > stx:
            stpc = stpmax
        else:
            stpc = stpmin

        stpq = quadratic_interpolate_g(stp, dp, stx, dx)

        if brackt:
            if abs(stp - stpc) < abs(stp - stpq):
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation within bracket.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation within bracket.")
        else:
            if abs(stp - stpc) > abs(stp - stpq):
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation outside bracket.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation outside bracket.")

    # Fourth case
    else:
        info = 4
        bound = False
        if brackt:
            stpc = cubic_interpolate(sty, fy, dy, stp, fp, dp, ignore_warnings=True)
            if np.isnan(stpc):
                stpc = (sty + stp) / 2.0
                if verbose:
                    print("cstep: Cubic interpolation failed within bracket, using bisection.")
            if safeguard_cubic:
                stpf = ensure_min_alpha(stpc, stx, sty)
                if verbose:
                    print("cstep: Using safeguarded cubic interpolation within bracket.")
            else:
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation within bracket.")
        elif stp > stx:
            stpf = stpmax
            if verbose:
                print("cstep: Using upper bound for step size.")
        else:
            stpf = stpmin
            if verbose:
                print("cstep: Using lower bound for step size.")

    # Update the interval of uncertainty
    if fp > fx:
        stepy = Step(alpha=stp, f=fp, d=dp)
    else:
        if sgnd < 0.0:
            stepy = stepx
        stepx = Step(alpha=stp, f=fp, d=dp)

    # Compute the new step and safeguard it
    stpf = min(stpmax, stpf)
    stpf = max(stpmin, stpf)
    stp = stpf

    if brackt and bound:
        # If the new step is too close to an endpoint, replace with weighted bisection
        if verbose:
            print("cstep: Step too close to end point, weighted bisection")
        stb = stx + delta * (sty - stx)
        if sty > stx:
            stp = min(stb, stp)
        else:
            stp = max(stb, stp)

    # Update step
    step = Step(alpha=stp, f=fp, d=dp)

    return {
        'stepx': stepx,
        'stepy': stepy,
        'step': step,
        'brackt': brackt,
        'info': info
    }


def cvsrch(phi: Callable[[float], Dict[str, float]],
           step0: Step, alpha: float = 1.0,
           c1: float = 1e-4, c2: float = 0.1,
           xtol: float = np.finfo(float).eps,
           alpha_min: float = 0.0, alpha_max: float = np.inf,
           maxfev: int = int(1e9), delta: float = 0.66,
           armijo_check_fn: Callable[[Step, Step, float], bool] = armijo_ok_step,
           wolfe_ok_step_fn: Optional[Callable[[Step, Step, float, float], bool]] = None,
           safeguard_cubic: bool = False,
           verbose: bool = False) -> Dict[str, Any]:
    """
    Performs the More-Thuente line search to find an acceptable step size.

    Args:
        phi (Callable[[float], Dict[str, float]]): Objective function along the search direction.
        step0 (Step): Initial step before the line search.
        alpha (float, optional): Initial guess for the step size. Defaults to 1.0.
        c1 (float, optional): Armijo condition constant. Defaults to 1e-4.
        c2 (float, optional): Curvature condition constant. Defaults to 0.1.
        xtol (float, optional): Relative width tolerance. Defaults to machine epsilon.
        alpha_min (float, optional): Minimum acceptable step size. Defaults to 0.0.
        alpha_max (float, optional): Maximum acceptable step size. Defaults to infinity.
        maxfev (int, optional): Maximum number of function evaluations allowed. Defaults to a large number.
        delta (float, optional): Minimum fraction of the interval range that the step size must decrease by. Defaults to 0.66.
        armijo_check_fn (Callable[[Step, Step, float], bool], optional): Function to check the Armijo condition. Defaults to `armijo_ok_step`.
        wolfe_ok_step_fn (Optional[Callable[[Step, Step, float, float], bool]], optional): Function to check the Wolfe conditions. Must be provided.
        safeguard_cubic (bool, optional): If True, ensures the cubic step isn't too close to the bounds. Defaults to False.
        verbose (bool, optional): If True, prints debug information. Defaults to False.

    Returns:
        Dict[str, Any]: Contains the best step found (`step`), number of function evaluations (`nfn`), gradient evaluations (`ngr`), and termination info (`info` and `message`).
    """
    if wolfe_ok_step_fn is None:
        raise ValueError("wolfe_ok_step_fn must be provided")

    # Initialize
    xtrapf = 4.0
    infoc = 1

    if maxfev == 0:
        message = get_termination_message(3)
        return {'step': step0, 'nfn': 0, 'ngr': 0, 'info': 3, 'message': message}

    # Check that the direction is descent
    if step0.d >= 0.0:
        message = get_termination_message(6)
        return {'step': step0, 'nfn': 0, 'ngr': 0, 'info': 6, 'message': message}

    dgtest = c1 * step0.d

    # Initialize local variables
    bracketed = False
    brackt = False
    stage1 = True
    nfev = 0

    width = alpha_max - alpha_min
    width_old = 2.0 * width

    stepx = step0
    stepy = step0
    step = Step(alpha=alpha, f=np.nan, d=np.nan)  # To be filled by find_finite

    while True:
        # Set the minimum and maximum steps to correspond to the present interval of uncertainty
        if brackt:
            stmin = min(stepx.alpha, stepy.alpha)
            stmax = max(stepx.alpha, stepy.alpha)
        else:
            stmin = stepx.alpha
            # Handle cases where step.alpha might be NaN initially
            if np.isnan(step.alpha):
                stmax = alpha_max
            else:
                stmax = step.alpha + xtrapf * (step.alpha - stepx.alpha)

        # Force the step to be within the bounds alpha_max and alpha_min
        if not np.isnan(step.alpha):
            step.alpha = max(step.alpha, alpha_min)
            step.alpha = min(step.alpha, alpha_max)
        else:
            step.alpha = alpha_min

        if verbose:
            print(f"Bracket: [{stmin}, {stmax}] alpha = {step.alpha}")

        # Evaluate the function and gradient at alpha
        # and compute the directional derivative
        ffres = find_finite(phi, step.alpha, maxfev - nfev, min_alpha=stmin)
        nfev += ffres['nfn']
        if not ffres['ok']:
            if verbose:
                print("Unable to create finite alpha")
            message = get_termination_message(7)
            return {'step': step0, 'nfn': nfev, 'ngr': nfev, 'info': 7, 'message': message}
        step = ffres['step']

        # Test for convergence
        info = check_convergence(step0, step, brackt, infoc, stmin, stmax,
                                 alpha_min, alpha_max, c1, c2, nfev,
                                 maxfev, xtol,
                                 armijo_check_fn=armijo_check_fn,
                                 wolfe_ok_step_fn=wolfe_ok_step_fn,
                                 verbose=verbose)

        # Check for termination
        if info != 0:
            # If an unusual termination is to occur, then set step to the best step found
            if info in [2, 3, 6]:
                step = stepx
            message = get_termination_message(info)
            if verbose:
                print(f"alpha = {step.alpha}")
            return {'step': step, 'nfn': nfev, 'ngr': nfev, 'info': info, 'message': message}

        # In the first stage we seek a step for which the modified
        # function has a nonpositive value and nonnegative derivative
        if stage1 and wolfe_ok_step_fn(step0, step, c1, min(c1, c2)):
            stage1 = False

        # A modified function is used to predict the step only if
        # we have not obtained a step for which the modified
        # function has a nonpositive function value and nonnegative
        # derivative, and if a lower function value has been
        # obtained but the decrease is not sufficient
        if stage1 and step.f <= stepx.f and not armijo_check_fn(step0, step, c1):
            # Define the modified function and derivative values
            stepxm = modify_step(stepx, dgtest)
            stepym = modify_step(stepy, dgtest)
            stepm = modify_step(step, dgtest)

            step_result = cstep(stepxm, stepym, stepm, brackt, stmin, stmax,
                                safeguard_cubic=safeguard_cubic,
                                verbose=verbose)

            brackt = step_result['brackt']
            infoc = step_result['info']
            stepxm = step_result['stepx']
            stepym = step_result['stepy']
            stepm = step_result['step']

            # Reset the function and gradient values for f
            stepx = unmodify_step(stepxm, dgtest)
            stepy = unmodify_step(stepym, dgtest)
            step = Step(alpha=stepm.alpha,
                        f=stepm.f + stepm.alpha * dgtest,
                        d=stepm.d + dgtest)
        else:
            # Call cstep to update the interval of uncertainty
            # and to compute the new step.
            step_result = cstep(stepx, stepy, step, brackt, stmin, stmax,
                                safeguard_cubic=safeguard_cubic,
                                verbose=verbose)
            brackt = step_result['brackt']
            infoc = step_result['info']
            stepx = step_result['stepx']
            stepy = step_result['stepy']
            step = step_result['step']

        if not bracketed and brackt:
            bracketed = True
            if verbose:
                print("Bracketed")

        # Force a sufficient decrease in the size of the interval of uncertainty
        if brackt:
            width_new = abs(stepy.alpha - stepx.alpha)
            if width_new >= delta * width_old:
                if verbose:
                    print("Interval did not decrease sufficiently: bisecting")
                step.alpha = stepx.alpha + 0.5 * (stepy.alpha - stepx.alpha)
            width_old = width
            width = width_new


def check_convergence(step0: Step, step: Step, brackt: bool, infoc: int,
                     stmin: float, stmax: float,
                     alpha_min: float, alpha_max: float,
                     c1: float, c2: float, nfev: int,
                     maxfev: int, xtol: float,
                     armijo_check_fn: Callable[[Step, Step, float], bool],
                     wolfe_ok_step_fn: Callable[[Step, Step, float, float], bool],
                     verbose: bool = False) -> int:
    """
    Checks whether the line search has converged based on various criteria.

    Args:
        step0 (Step): Initial step before the line search.
        step (Step): Current step being evaluated.
        brackt (bool): Whether the step has been bracketed.
        infoc (int): Return code from the last step size update.
        stmin (float): Smallest value of the step size interval.
        stmax (float): Largest value of the step size interval.
        alpha_min (float): Minimum acceptable step size.
        alpha_max (float): Maximum acceptable step size.
        c1 (float): Armijo condition constant.
        c2 (float): Curvature condition constant.
        nfev (int): Current number of function evaluations.
        maxfev (int): Maximum number of function evaluations allowed.
        xtol (float): Relative width tolerance.
        armijo_check_fn (Callable[[Step, Step, float], bool]): Function to check the Armijo condition.
        wolfe_ok_step_fn (Callable[[Step, Step, float, float], bool]): Function to check the Wolfe conditions.
        verbose (bool): If True, prints debug information.

    Returns:
        int: Integer code indicating the convergence state.
            0: No convergence.
            1: The sufficient decrease condition and the directional derivative condition hold.
            2: Relative width of the interval of uncertainty is at most xtol.
            3: Number of calls to phi has reached maxfev.
            4: The step is at the lower bound alpha_min.
            5: The step is at the upper bound alpha_max.
            6: Rounding errors prevent further progress.
            7: Unable to create finite alpha.
    """
    info = 0
    if (brackt and (step.alpha <= stmin or step.alpha >= stmax)) or (infoc == 0):
        if verbose:
            print(f"MT: Rounding errors prevent further progress: stmin = {stmin}, stmax = {stmax}")
        # Rounding errors prevent further progress
        info = 6

    if (step.alpha == alpha_max and armijo_check_fn(step0, step, c1) and
            not curvature_ok_step(step0, step, c2)):
        # Reached alpha_max
        info = 5
        if verbose:
            print("MT: Reached alpha max")

    if (step.alpha == alpha_min and
            (not armijo_check_fn(step0, step, c1) or
             curvature_ok_step(step0, step, c2))):
        # Reached alpha_min
        info = 4
        if verbose:
            print("MT: Reached alpha min")

    if nfev >= maxfev:
        # Maximum number of function evaluations reached
        info = 3
        if verbose:
            print("MT: Exceeded maximum number of function evaluations")

    if brackt and (stmax - stmin) <= xtol * stmax:
        # Interval width is below xtol
        info = 2
        if verbose:
            print(f"MT: Interval width is <= xtol: {xtol * stmax}")

    if wolfe_ok_step_fn(step0, step, c1, c2):
        # Success
        info = 1
        if verbose:
            print("Success! Step satisfies Wolfe conditions.")

    return info


def cstep(stepx: Step, stepy: Step, step: Step, brackt: bool,
          stpmin: float, stpmax: float,
          safeguard_cubic: bool = False,
          verbose: bool = False) -> Dict[str, Any]:
    """
    Updates the interval of uncertainty and computes a new trial step.

    Args:
        stepx (Step): One side of the updated step interval.
        stepy (Step): Other side of the updated step interval.
        step (Step): Current trial step.
        brackt (bool): Whether the step has been bracketed.
        stpmin (float): Minimum allowed step size.
        stpmax (float): Maximum allowed step size.
        safeguard_cubic (bool): If True, ensures the cubic step isn't too close to the bounds.
        verbose (bool): If True, prints debug information.

    Returns:
        Dict[str, Any]: Contains updated 'stepx', 'stepy', 'step', 'brackt', and 'info'.
    """
    stx = stepx.alpha
    fx = stepx.f
    dx = stepx.d

    sty = stepy.alpha
    fy = stepy.f
    dy = stepy.d

    stp = step.alpha
    fp = step.f
    dp = step.d

    delta = 0.66
    info = 0

    # Check the input parameters for errors
    if (brackt and (stp <= min(stx, sty) or stp >= max(stx, sty))) or \
            (dx * (stp - stx) >= 0.0) or (stpmax < stpmin):
        # Improper input parameters
        if verbose:
            print("cstep: Improper input parameters detected.")
        return {
            'stepx': stepx,
            'stepy': stepy,
            'step': step,
            'brackt': brackt,
            'info': info
        }

    # Determine if the derivatives have opposite sign
    sgnd = dp * np.sign(dx)  # Equivalent to dp * sign(dx)

    # First case: fp > fx
    if fp > fx:
        info = 1
        bound = True

        stpc = cubic_interpolate(stx, fx, dx, stp, fp, dp, ignore_warnings=True)
        stpq = quadratic_interpolate(stx, fx, dx, stp, fp)

        if np.isnan(stpc):
            stpf = stpq
            if verbose:
                print("cstep: Cubic interpolation failed, using quadratic interpolation.")
        else:
            if abs(stpc - stx) < abs(stpq - stx):
                if safeguard_cubic:
                    stpf = ensure_min_alpha(stpc, stx, sty)
                    if verbose:
                        print("cstep: Using safeguarded cubic interpolation.")
                else:
                    stpf = stpc
                    if verbose:
                        print("cstep: Using cubic interpolation.")
            else:
                stpf = stpc + (stpq - stpc) / 2.0
                if verbose:
                    print("cstep: Using combination of cubic and quadratic interpolation.")

        brackt = True

    # Second case: sgnd < 0
    elif sgnd < 0.0:
        info = 2
        bound = False

        stpc = cubic_interpolate(stx, fx, dx, stp, fp, dp, ignore_warnings=True)
        stpq = quadratic_interpolate_g(stp, dp, stx, dx)

        if np.isnan(stpc):
            stpf = stpq
            if verbose:
                print("cstep: Cubic interpolation failed, using quadratic interpolation based on derivatives.")
        else:
            if abs(stpc - stp) > abs(stpq - stp):
                if safeguard_cubic:
                    stpf = ensure_min_alpha(stpc, stx, sty)
                    if verbose:
                        print("cstep: Using safeguarded cubic interpolation.")
                else:
                    stpf = stpc
                    if verbose:
                        print("cstep: Using cubic interpolation.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation based on derivatives.")

        brackt = True

    # Third case: |dp| < |dx|
    elif abs(dp) < abs(dx):
        info = 3
        bound = True
        theta = 3.0 * (fx - fp) / (stp - stx) + dx + dp
        s = np.linalg.norm([theta, dx, dp], ord=np.inf)  # Infinity norm
        gamma_sq = (theta / s)**2 - (dx / s) * (dp / s)
        gamma = s * np.sqrt(max(0.0, gamma_sq))
        if stp > stx:
            gamma = -gamma
        p = (gamma - dp) + theta
        q = (gamma + (dx - dp)) + gamma
        if q == 0.0:
            r = 0.0
        else:
            r = p / q

        if r < 0.0 and gamma != 0.0:
            stpc = stp + r * (stx - stp)
        elif stp > stx:
            stpc = stpmax
        else:
            stpc = stpmin

        stpq = quadratic_interpolate_g(stp, dp, stx, dx)

        if brackt:
            if abs(stp - stpc) < abs(stp - stpq):
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation within bracket.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation within bracket.")
        else:
            if abs(stp - stpc) > abs(stp - stpq):
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation outside bracket.")
            else:
                stpf = stpq
                if verbose:
                    print("cstep: Using quadratic interpolation outside bracket.")

    # Fourth case
    else:
        info = 4
        bound = False
        if brackt:
            stpc = cubic_interpolate(sty, fy, dy, stp, fp, dp, ignore_warnings=True)
            if np.isnan(stpc):
                stpc = (sty + stp) / 2.0
                if verbose:
                    print("cstep: Cubic interpolation failed within bracket, using bisection.")
            if safeguard_cubic:
                stpf = ensure_min_alpha(stpc, stx, sty)
                if verbose:
                    print("cstep: Using safeguarded cubic interpolation within bracket.")
            else:
                stpf = stpc
                if verbose:
                    print("cstep: Using cubic interpolation within bracket.")
        elif stp > stx:
            stpf = stpmax
            if verbose:
                print("cstep: Using upper bound for step size.")
        else:
            stpf = stpmin
            if verbose:
                print("cstep: Using lower bound for step size.")

    # Update the interval of uncertainty
    if fp > fx:
        stepy = Step(alpha=stp, f=fp, d=dp)
    else:
        if sgnd < 0.0:
            stepy = stepx
        stepx = Step(alpha=stp, f=fp, d=dp)

    # Compute the new step and safeguard it
    stpf = min(stpmax, stpf)
    stpf = max(stpmin, stpf)
    stp = stpf

    if brackt and bound:
        # If the new step is too close to an endpoint, replace with weighted bisection
        if verbose:
            print("cstep: Step too close to end point, weighted bisection")
        stb = stx + delta * (sty - stx)
        if sty > stx:
            stp = min(stb, stp)
        else:
            stp = max(stb, stp)

    # Update step
    step = Step(alpha=stp, f=fp, d=dp)

    return {
        'stepx': stepx,
        'stepy': stepy,
        'step': step,
        'brackt': brackt,
        'info': info
    }


def cvsrch(phi: Callable[[float], Dict[str, float]],
           step0: Step, alpha: float = 1.0,
           c1: float = 1e-4, c2: float = 0.1,
           xtol: float = np.finfo(float).eps,
           alpha_min: float = 0.0, alpha_max: float = np.inf,
           maxfev: int = int(1e9), delta: float = 0.66,
           armijo_check_fn: Callable[[Step, Step, float], bool] = armijo_ok_step,
           wolfe_ok_step_fn: Optional[Callable[[Step, Step, float, float], bool]] = None,
           safeguard_cubic: bool = False,
           verbose: bool = False) -> Dict[str, Any]:
    """
    Performs the More-Thuente line search to find an acceptable step size.

    Args:
        phi (Callable[[float], Dict[str, float]]): Objective function along the search direction.
        step0 (Step): Initial step before the line search.
        alpha (float, optional): Initial guess for the step size. Defaults to 1.0.
        c1 (float, optional): Armijo condition constant. Defaults to 1e-4.
        c2 (float, optional): Curvature condition constant. Defaults to 0.1.
        xtol (float, optional): Relative width tolerance. Defaults to machine epsilon.
        alpha_min (float, optional): Minimum acceptable step size. Defaults to 0.0.
        alpha_max (float, optional): Maximum acceptable step size. Defaults to infinity.
        maxfev (int, optional): Maximum number of function evaluations allowed. Defaults to a large number.
        delta (float, optional): Minimum fraction of the interval range that the step size must decrease by. Defaults to 0.66.
        armijo_check_fn (Callable[[Step, Step, float], bool], optional): Function to check the Armijo condition. Defaults to `armijo_ok_step`.
        wolfe_ok_step_fn (Optional[Callable[[Step, Step, float, float], bool]], optional): Function to check the Wolfe conditions. Must be provided.
        safeguard_cubic (bool, optional): If True, ensures the cubic step isn't too close to the bounds. Defaults to False.
        verbose (bool, optional): If True, prints debug information. Defaults to False.

    Returns:
        Dict[str, Any]: Contains the best step found (`step`), number of function evaluations (`nfn`),
                        gradient evaluations (`ngr`), and termination info (`info` and `message`).
    """
    if wolfe_ok_step_fn is None:
        raise ValueError("wolfe_ok_step_fn must be provided")

    # Initialize
    xtrapf = 4.0
    infoc = 1

    if maxfev == 0:
        message = get_termination_message(3)
        return {'step': step0, 'nfn': 0, 'ngr': 0, 'info': 3, 'message': message}

    # Check that the direction is descent
    if step0.d >= 0.0:
        message = get_termination_message(6)
        return {'step': step0, 'nfn': 0, 'ngr': 0, 'info': 6, 'message': message}

    dgtest = c1 * step0.d

    # Initialize local variables
    bracketed = False
    brackt = False
    stage1 = True
    nfev = 0

    width = alpha_max - alpha_min
    width_old = 2.0 * width

    stepx = step0
    stepy = step0
    step = Step(alpha=alpha, f=np.nan, d=np.nan)  # To be filled by find_finite

    while True:
        # Set the minimum and maximum steps to correspond to the present interval of uncertainty
        if brackt:
            stmin = min(stepx.alpha, stepy.alpha)
            stmax = max(stepx.alpha, stepy.alpha)
        else:
            stmin = stepx.alpha
            # Handle cases where step.alpha might be NaN initially
            if np.isnan(step.alpha):
                stmax = alpha_max
            else:
                stmax = step.alpha + xtrapf * (step.alpha - stepx.alpha)

        # Force the step to be within the bounds alpha_max and alpha_min
        if not np.isnan(step.alpha):
            step.alpha = max(step.alpha, alpha_min)
            step.alpha = min(step.alpha, alpha_max)
        else:
            step.alpha = alpha_min

        if verbose:
            print(f"Bracket: [{stmin}, {stmax}] alpha = {step.alpha}")

        # Evaluate the function and gradient at alpha
        # and compute the directional derivative
        ffres = find_finite(phi, step.alpha, maxfev - nfev, min_alpha=stmin)
        nfev += ffres['nfn']
        if not ffres['ok']:
            if verbose:
                print("Unable to create finite alpha")
            message = get_termination_message(7)
            return {'step': step0, 'nfn': nfev, 'ngr': nfev, 'info': 7, 'message': message}
        step = ffres['step']

        # Test for convergence
        info = check_convergence(step0, step, brackt, infoc, stmin, stmax,
                                 alpha_min, alpha_max, c1, c2, nfev,
                                 maxfev, xtol,
                                 armijo_check_fn=armijo_check_fn,
                                 wolfe_ok_step_fn=wolfe_ok_step_fn,
                                 verbose=verbose)

        # Check for termination
        if info != 0:
            # If an unusual termination is to occur, then set step to the best step found
            if info in [2, 3, 6]:
                step = stepx
            message = get_termination_message(info)
            if verbose:
                print(f"alpha = {step.alpha}")
            return {'step': step, 'nfn': nfev, 'ngr': nfev, 'info': info, 'message': message}

        # In the first stage we seek a step for which the modified
        # function has a nonpositive value and nonnegative derivative
        if stage1 and wolfe_ok_step_fn(step0, step, c1, min(c1, c2)):
            stage1 = False

        # A modified function is used to predict the step only if
        # we have not obtained a step for which the modified
        # function has a nonpositive function value and nonnegative
        # derivative, and if a lower function value has been
        # obtained but the decrease is not sufficient
        if stage1 and step.f <= stepx.f and not armijo_check_fn(step0, step, c1):
            # Define the modified function and derivative values
            stepxm = modify_step(stepx, dgtest)
            stepym = modify_step(stepy, dgtest)
            stepm = modify_step(step, dgtest)

            step_result = cstep(stepxm, stepym, stepm, brackt, stmin, stmax,
                                safeguard_cubic=safeguard_cubic,
                                verbose=verbose)

            brackt = step_result['brackt']
            infoc = step_result['info']
            stepxm = step_result['stepx']
            stepym = step_result['stepy']
            stepm = step_result['step']

            # Reset the function and gradient values for f
            stepx = unmodify_step(stepxm, dgtest)
            stepy = unmodify_step(stepym, dgtest)
            step = Step(alpha=stepm.alpha,
                        f=stepm.f + stepm.alpha * dgtest,
                        d=stepm.d + dgtest)
        else:
            # Call cstep to update the interval of uncertainty
            # and to compute the new step
            step_result = cstep(stepx, stepy, step, brackt, stmin, stmax,
                                safeguard_cubic=safeguard_cubic,
                                verbose=verbose)
            brackt = step_result['brackt']
            infoc = step_result['info']
            stepx = step_result['stepx']
            stepy = step_result['stepy']
            step = step_result['step']

        if not bracketed and brackt:
            bracketed = True
            if verbose:
                print("Bracketed")

        # Force a sufficient decrease in the size of the interval of uncertainty
        if brackt:
            width_new = abs(stepy.alpha - stepx.alpha)
            if width_new >= delta * width_old:
                if verbose:
                    print("Interval did not decrease sufficiently: bisecting")
                step.alpha = stepx.alpha + 0.5 * (stepy.alpha - stepx.alpha)
            width_old = width
            width = width_new


def more_thuente(c1: float = 1e-4, c2: float = 0.1,
                max_fn: float = np.inf, eps: float = 1e-6,
                alpha_max: float = np.inf,
                approx_armijo: bool = False,
                strong_curvature: bool = True,
                safeguard_cubic: bool = False,
                verbose: bool = False) -> Callable[..., Dict[str, Any]]:
    """
    Factory function that creates a More-Thuente line search function.

    Args:
        c1 (float, optional): Armijo condition constant. Defaults to 1e-4.
        c2 (float, optional): Curvature condition constant. Defaults to 0.1.
        max_fn (float, optional): Maximum number of function evaluations allowed. Defaults to infinity.
        eps (float, optional): Tolerance for approximate Armijo condition. Defaults to 1e-6.
        alpha_max (float, optional): Maximum acceptable step size. Defaults to infinity.
        approx_armijo (bool, optional): If True, uses an approximate Armijo condition. Defaults to False.
        strong_curvature (bool, optional): If True, checks the strong Wolfe curvature condition. Defaults to True.
        safeguard_cubic (bool, optional): If True, ensures the cubic step isn't too close to the bounds. Defaults to False.
        verbose (bool, optional): If True, prints debug information. Defaults to False.

    Returns:
        Callable[..., Dict[str, Any]]: A line search function that can be called with specific arguments.
    """
    if approx_armijo:
        armijo_check_fn = make_approx_armijo_ok_step(eps)
    else:
        armijo_check_fn = armijo_ok_step

    wolfe_ok_step_fn = make_wolfe_ok_step_fn(
        strong_curvature=strong_curvature,
        approx_armijo=approx_armijo,
        eps=eps
    )

    def line_search(phi: Callable[[float], Dict[str, float]],
                   step0: Step, alpha: float = 1.0,
                   total_max_fn: float = np.inf,
                   total_max_gr: float = np.inf,
                   total_max_fg: float = np.inf,
                   pm: Optional[Any] = None) -> Dict[str, Any]:
        """
        Executes the line search.

        Args:
            phi (Callable[[float], Dict[str, float]]): Objective function along the search direction.
            step0 (Step): Initial step before the line search.
            alpha (float, optional): Initial guess for the step size. Defaults to 1.0.
            total_max_fn (float, optional): Maximum function evaluations from external limits. Defaults to infinity.
            total_max_gr (float, optional): Maximum gradient evaluations from external limits. Defaults to infinity.
            total_max_fg (float, optional): Maximum function and gradient evaluations combined. Defaults to infinity.
            pm (Optional[Any], optional): Additional parameters (unused).

        Returns:
            Dict[str, Any]: Contains the best step found (`step`), number of function evaluations (`nfn`),
                            gradient evaluations (`ngr`), and termination info (`info` and `message`).
        """
        maxfev = min(max_fn, total_max_fn, total_max_gr, int(total_max_fg / 2))
        if maxfev <= 0:
            message = get_termination_message(3)
            return {'step': step0, 'nfn': 0, 'ngr': 0, 'info': 3, 'message': message}
        res = cvsrch(phi, step0,
                     alpha=alpha, c1=c1, c2=c2,
                     maxfev=maxfev, alpha_max=alpha_max,
                     armijo_check_fn=armijo_check_fn,
                     wolfe_ok_step_fn=wolfe_ok_step_fn,
                     safeguard_cubic=safeguard_cubic,
                     verbose=verbose)
        return {'step': res['step'], 'nfn': res['nfn'], 'ngr': res['nfn'], 'info': res['info'], 'message': get_termination_message(res['info'])}

    return line_search

In [2]:
if __name__ == "__main__":
    # Define the objective function and its derivative along the search direction
    def phi(alpha: float) -> Dict[str, float]:
        # Quadratic function f(alpha) = (x0 + alpha * p)^2
        x0 = 1.0  # Starting point
        p = -2.0  # Search direction (e.g., negative gradient)
        x = x0 + alpha * p
        f = x ** 2
        d = 2 * x * p  # Directional derivative
        return {'f': f, 'd': d}

    # Initialize step0
    initial_phi = phi(0.0)
    step0 = Step(alpha=0.0, f=initial_phi['f'], d=initial_phi['d'])

    # Create the line search function
    line_search = more_thuente(c1=1e-4, c2=0.9, max_fn=20, verbose=True)

    # Perform the line search
    result = line_search(
        phi=phi,
        step0=step0,
        alpha=1.0,
        total_max_fn=100,
        total_max_gr=100,
        total_max_fg=200,
        pm=None
    )

    # Access the results
    best_step = result['step']
    function_evals = result['nfn']
    termination_info = result['info']
    termination_message = result['message']

    print("\nLine Search Result:")
    print(f"Best step size: {best_step.alpha}")
    print(f"Function value at best step: {best_step.f}")
    print(f"Directional derivative at best step: {best_step.d}")
    print(f"Number of function evaluations: {function_evals}")
    print(f"Termination info code: {termination_info}")
    print(f"Termination Reason: {termination_message}")

Bracket: [0.0, 5.0] alpha = 1.0
cstep: Cubic interpolation failed, using quadratic interpolation.
cstep: Step too close to end point, weighted bisection
Bracketed
Bracket: [0.0, 1.0] alpha = 0.5000500100020004
Success! Step satisfies Wolfe conditions.
alpha = 0.5000500100020004

Line Search Result:
Best step size: 0.5000500100020004
Function value at best step: 1.0004001200337887e-08
Directional derivative at best step: 0.0004000800160035567
Number of function evaluations: 2
Termination info code: 1
Termination Reason: The sufficient decrease condition and the directional derivative condition hold.
