Skip to content

Commit

Permalink
Merge pull request #59 from crealytics/master
Browse files Browse the repository at this point in the history
Use StrongWolfeLineSearch implementation in LBFGS
  • Loading branch information
dlwh committed Apr 24, 2013
2 parents 4f79aef + 970d9e1 commit 43d7e2c
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 2 deletions.
2 changes: 1 addition & 1 deletion learn/src/main/scala/breeze/optimize/LBFGS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class LBFGS[T](maxIter: Int = -1, m: Int=10, tolerance: Double=1E-9)
}

val ff = LineSearch.functionFromSearchDirection(f, x, dir)
val search = new BacktrackingLineSearch(shrinkStep = if(iter < 1) 0.01 else 0.5)
val search = new StrongWolfeLineSearch(maxZoomIter = 10, maxLineSearchIter = 10) // TODO: Need good default values here.
val alpha = search.minimize(ff, if(state.iter == 0.0) 1.0/norm(dir) else 1.0)

if(alpha * norm(grad) < 1E-10)
Expand Down
6 changes: 5 additions & 1 deletion learn/src/main/scala/breeze/optimize/LineSearch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package breeze.optimize

import breeze.math.InnerProductSpace

trait MinimizingLineSearch {
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double
}

/**
* A line search optimizes a function of one variable without
* analytic gradient information. Differs only in whether or not it tries to find an exact minimizer
Expand All @@ -15,7 +19,7 @@ trait LineSearch extends ApproximateLineSearch
* backtracking line search), where there is no intrinsic termination criterion, only extrinsic
* @author dlwh
*/
trait ApproximateLineSearch {
trait ApproximateLineSearch extends MinimizingLineSearch {
final case class State(alpha: Double, value: Double, deriv: Double)
def iterations(f: DiffFunction[Double], init: Double = 1.0):Iterator[State]
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double = iterations(f, init).reduceLeft( (a,b) => b).alpha
Expand Down
186 changes: 186 additions & 0 deletions learn/src/main/scala/breeze/optimize/StrongWolfe.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package breeze.optimize

import breeze.linalg._
import breeze.numerics._
import com.typesafe.scalalogging.log4j.Logging

abstract class CubicLineSearch extends Logging with MinimizingLineSearch {
import logger._
import scala.math._

case class Bracket(
t: Double, // 1d line search parameter
dd: Double, // Directional Derivative at t
fval: Double // Function value at t
)

/*
* Invoke line search, returning stepsize
*/
def minimize(f: DiffFunction[Double], init: Double = 1.0): Double
/*
* Cubic interpolation to find the minimum inside the bracket l and r.
* Uses the fval and gradient at the left and right side, which gives
* the four bits of information required to interpolate a cubic.
* This is additionally "safe-guarded" whereby steps too close to
* either side of the interval will not be selected.
*/
def interp(l: Bracket, r: Bracket) = {
// See N&W p57 actual for an explanation of the math
val d1 = l.dd + r.dd - 3 * (l.fval - r.fval) / (l.t - r.t)
val d2 = sqrt(d1 * d1 - l.dd * r.dd)
val multipler = r.t - l.t
val t = r.t - multipler * (r.dd + d2 - d1) / (r.dd - l.dd + 2 * d2)

// If t is too close to either end bracket, move it closer to the middle

val lbound = l.t + 0.1 * (r.t - l.t)
val ubound = l.t + 0.9 * (r.t - l.t)
t match {
case _ if t < lbound =>
debug("Cubic " + t + " below LHS limit: " + lbound)
lbound
case _ if t > ubound =>
debug("Cubic " + t + " above RHS limit: " + ubound)
ubound
case _ => t
}
}
}

/*
* This line search will attempt steps larger than step length one,
* unlike back-tracking line searches. It also comes with strong convergence
* properties. It selects step lengths using cubic interpolation, which
* works better than other approaches in my experience.
* Based on Nocedal & Wright.
*/
class StrongWolfeLineSearch(maxZoomIter: Int, maxLineSearchIter: Int) extends CubicLineSearch {
import logger._
import scala.math._

val c1 = 1e-4
val c2 = 0.9

/**
* Performs a line search on the function f, returning a point satisfying
* the Strong Wolfe conditions. Based on the line search detailed in
* Nocedal & Wright Numerical Optimization p58.
*/
def minimize(f: DiffFunction[Double], init: Double = 1.0):Double = {

def phi(t: Double): Bracket = {
val (pval, pdd) = f.calculate(t)
Bracket(t = t, dd = pdd, fval = pval)
}

var t = init // Search's current multiple of pk
var low = phi(0.0)
val fval = low.fval
val dd = low.dd

if (dd > 0) {
throw new FirstOrderException("Line search invoked with non-descent direction: " + dd)
}

/**
* Assuming a point satisfying the strong wolfe conditions exists within
* the passed interval, this method finds it by iteratively refining the
* interval. Nocedal & Wright give the following invariants for zoom's loop:
*
* - The interval bounded by low.t and hi.t contains a point satisfying the
* strong Wolfe conditions.
* - Among all points evaluated so far that satisfy the "sufficient decrease"
* condition, low.t is the one with the smallest fval.
* - hi.t is chosen so that low.dd * (hi.t - low.t) < 0.
*/
def zoom(linit: Bracket, rinit: Bracket): Double = {

var low = linit
var hi = rinit

for (i <- 0 until maxZoomIter) {
// Interp assumes left less than right in t value, so flip if needed
val t = if (low.t > hi.t) interp(hi, low) else interp(low, hi)

// Evaluate objective at t, and build bracket
val c = phi(t)
//debug("ZOOM:\n c: " + c + " \n l: " + low + " \nr: " + hi)
info("Line search t: " + t + " fval: " + c.fval +
" rhs: " + (fval + c1 * c.t * dd) + " cdd: " + c.dd)

///////////////
/// Update left or right bracket, or both

if (c.fval > fval + c1 * c.t * dd || c.fval >= low.fval) {
// "Sufficient decrease" condition not satisfied by c. Shrink interval at right
hi = c
debug("hi=c")
} else {

// Zoom exit condition is the "curvature" condition
// Essentially that the directional derivative is large enough
if (abs(c.dd) <= c2 * abs(dd)) {
return c.t
}

// If the signs don't coincide, flip left to right before updating l to c
if (c.dd * (hi.t - low.t) >= 0) {
debug("flipping")
hi = low
}

debug("low=c")
// If curvature condition not satisfied, move the left hand side of the
// interval further away from t=0.
low = c
}
}

throw new FirstOrderException(s"Line search zoom failed")
}

///////////////////////////////////////////////////////////////////

for (i <- 0 until maxLineSearchIter) {
val c = phi(t)

// If phi has a bounded domain, inf or nan usually indicates we took
// too large a step.
if (java.lang.Double.isInfinite(c.fval) || java.lang.Double.isNaN(c.fval)) {
t /= 2.0
error("Encountered bad values in function evaluation. Decreasing step size to " + t)
} else {

// Zoom if "sufficient decrease" condition is not satisfied
if ((c.fval > fval + c1 * t * dd) ||
(c.fval >= low.fval && i > 0)) {
debug("Line search t: " + t + " fval: " + c.fval + " cdd: " + c.dd)
return zoom(low, c)
}

// We don't need to zoom at all
// if the strong wolfe condition is satisfied already.
if (abs(c.dd) <= c2 * abs(dd)) {
return c.t
}

// If c.dd is positive, we zoom on the inverted interval.
// Occurs if we skipped over the nearest local minimum
// over to the next one.
if (c.dd >= 0) {
debug("Line search t: " + t + " fval: " + c.fval +
" rhs: " + (fval + c1 * t * dd) + " cdd: " + c.dd)
return zoom(c, low)
}

low = c
t *= 1.5
debug("Sufficent Decrease condition but not curvature condition satisfied. Increased t to: " + t)
}
}

throw new FirstOrderException("Line search failed")
}

}

0 comments on commit 43d7e2c

Please sign in to comment.