Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use StrongWolfeLineSearch implementation in LBFGS #59

Merged
merged 3 commits into from
Apr 24, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion learn/src/main/scala/breeze/optimize/LBFGS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class LBFGS[T](maxIter: Int = -1, m: Int=10, tolerance: Double=1E-9)
}

val ff = LineSearch.functionFromSearchDirection(f, x, dir)
val search = new BacktrackingLineSearch(shrinkStep = if(iter < 1) 0.01 else 0.5)
val search = new StrongWolfeLineSearch(maxZoomIter = 10, maxLineSearchIter = 10) // TODO: Need good default values here.
val alpha = search.minimize(ff, if(state.iter == 0.0) 1.0/norm(dir) else 1.0)

if(alpha * norm(grad) < 1E-10)
Expand Down
6 changes: 5 additions & 1 deletion learn/src/main/scala/breeze/optimize/LineSearch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package breeze.optimize

import breeze.math.InnerProductSpace

trait MinimizingLineSearch {
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double
}

/**
* A line search optimizes a function of one variable without
* analytic gradient information. Differs only in whether or not it tries to find an exact minimizer
Expand All @@ -15,7 +19,7 @@ trait LineSearch extends ApproximateLineSearch
* backtracking line search), where there is no intrinsic termination criterion, only extrinsic
* @author dlwh
*/
trait ApproximateLineSearch {
trait ApproximateLineSearch extends MinimizingLineSearch {
final case class State(alpha: Double, value: Double, deriv: Double)
def iterations(f: DiffFunction[Double], init: Double = 1.0):Iterator[State]
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double = iterations(f, init).reduceLeft( (a,b) => b).alpha
Expand Down
186 changes: 186 additions & 0 deletions learn/src/main/scala/breeze/optimize/StrongWolfe.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package breeze.optimize

import breeze.linalg._
import breeze.numerics._
import com.typesafe.scalalogging.log4j.Logging

abstract class CubicLineSearch extends Logging with MinimizingLineSearch {
import logger._
import scala.math._

case class Bracket(
t: Double, // 1d line search parameter
dd: Double, // Directional Derivative at t
fval: Double // Function value at t
)

/*
* Invoke line search, returning stepsize
*/
def minimize(f: DiffFunction[Double], init: Double = 1.0): Double
/*
* Cubic interpolation to find the minimum inside the bracket l and r.
* Uses the fval and gradient at the left and right side, which gives
* the four bits of information required to interpolate a cubic.
* This is additionally "safe-guarded" whereby steps too close to
* either side of the interval will not be selected.
*/
def interp(l: Bracket, r: Bracket) = {
// See N&W p57 actual for an explanation of the math
val d1 = l.dd + r.dd - 3 * (l.fval - r.fval) / (l.t - r.t)
val d2 = sqrt(d1 * d1 - l.dd * r.dd)
val multipler = r.t - l.t
val t = r.t - multipler * (r.dd + d2 - d1) / (r.dd - l.dd + 2 * d2)

// If t is too close to either end bracket, move it closer to the middle

val lbound = l.t + 0.1 * (r.t - l.t)
val ubound = l.t + 0.9 * (r.t - l.t)
t match {
case _ if t < lbound =>
debug("Cubic " + t + " below LHS limit: " + lbound)
lbound
case _ if t > ubound =>
debug("Cubic " + t + " above RHS limit: " + ubound)
ubound
case _ => t
}
}
}

/*
* This line search will attempt steps larger than step length one,
* unlike back-tracking line searches. It also comes with strong convergence
* properties. It selects step lengths using cubic interpolation, which
* works better than other approaches in my experience.
* Based on Nocedal & Wright.
*/
class StrongWolfeLineSearch(maxZoomIter: Int, maxLineSearchIter: Int) extends CubicLineSearch {
import logger._
import scala.math._

val c1 = 1e-4
val c2 = 0.9

/**
* Performs a line search on the function f, returning a point satisfying
* the Strong Wolfe conditions. Based on the line search detailed in
* Nocedal & Wright Numerical Optimization p58.
*/
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double = {

def phi(t: Double): Bracket = {
val (pval, pdd) = f.calculate(t)
Bracket(t = t, dd = pdd, fval = pval)
}

var t = init // Search's current multiple of pk
var low = phi(0.0)
val fval = low.fval
val dd = low.dd

if (dd > 0) {
throw new FirstOrderException("Line search invoked with non-descent direction: " + dd)
}

/**
* Assuming a point satisfying the strong wolfe conditions exists within
* the passed interval, this method finds it by iteratively refining the
* interval. Nocedal & Wright give the following invariants for zoom's loop:
*
* - The interval bounded by low.t and hi.t contains a point satisfying the
* strong Wolfe conditions.
* - Among all points evaluated so far that satisfy the "sufficient decrease"
* condition, low.t is the one with the smallest fval.
* - hi.t is chosen so that low.dd * (hi.t - low.t) < 0.
*/
def zoom(linit: Bracket, rinit: Bracket): Double = {

var low = linit
var hi = rinit

for (i <- 0 until maxZoomIter) {
// Interp assumes left less than right in t value, so flip if needed
val t = if (low.t > hi.t) interp(hi, low) else interp(low, hi)

// Evaluate objective at t, and build bracket
val c = phi(t)
//debug("ZOOM:\n c: " + c + " \n l: " + low + " \nr: " + hi)
info("Line search t: " + t + " fval: " + c.fval +
" rhs: " + (fval + c1 * c.t * dd) + " cdd: " + c.dd)

///////////////
/// Update left or right bracket, or both

if (c.fval > fval + c1 * c.t * dd || c.fval >= low.fval) {
// "Sufficient decrease" condition not satisfied by c. Shrink interval at right
hi = c
debug("hi=c")
} else {

// Zoom exit condition is the "curvature" condition
// Essentially that the directional derivative is large enough
if (abs(c.dd) <= c2 * abs(dd)) {
return c.t
}

// If the signs don't coincide, flip left to right before updating l to c
if (c.dd * (hi.t - low.t) >= 0) {
debug("flipping")
hi = low
}

debug("low=c")
// If curvature condition not satisfied, move the left hand side of the
// interval further away from t=0.
low = c
}
}

throw new FirstOrderException(s"Line search zoom failed")
}

///////////////////////////////////////////////////////////////////

for (i <- 0 until maxLineSearchIter) {
val c = phi(t)

// If phi has a bounded domain, inf or nan usually indicates we took
// too large a step.
if (java.lang.Double.isInfinite(c.fval) || java.lang.Double.isNaN(c.fval)) {
t /= 2.0
error("Encountered bad values in function evaluation. Decreasing step size to " + t)
} else {

// Zoom if "sufficient decrease" condition is not satisfied
if ((c.fval > fval + c1 * t * dd) ||
(c.fval >= low.fval && i > 0)) {
debug("Line search t: " + t + " fval: " + c.fval + " cdd: " + c.dd)
return zoom(low, c)
}

// We don't need to zoom at all
// if the strong wolfe condition is satisfied already.
if (abs(c.dd) <= c2 * abs(dd)) {
return c.t
}

// If c.dd is positive, we zoom on the inverted interval.
// Occurs if we skipped over the nearest local minimum
// over to the next one.
if (c.dd >= 0) {
debug("Line search t: " + t + " fval: " + c.fval +
" rhs: " + (fval + c1 * t * dd) + " cdd: " + c.dd)
return zoom(c, low)
}

low = c
t *= 1.5
debug("Sufficent Decrease condition but not curvature condition satisfied. Increased t to: " + t)
}
}

throw new FirstOrderException("Line search failed")
}

}