Expose linesearch parameter

j-towns committed Mar 3, 2016
1 parent faa803c commit 02a0dd7

Showing 3 changed files with 21 additions and 23 deletions.

pymanopt/solvers/conjugate_gradient.py (10 additions, 11 deletions)
@@ -4,7 +4,7 @@
 
 import numpy as np
 
-from pymanopt.solvers import linesearch as default_linesearchers
+from pymanopt.solvers.linesearch import LineSearchAdaptive
 from pymanopt.solvers.solver import Solver
 from pymanopt import tools
 
@@ -21,7 +21,7 @@ class ConjugateGradient(Solver):
     """
 
     def __init__(self, beta_type=BetaTypes.HestenesStiefel, orth_value=np.inf,
-                 linesearch=None, *args, **kwargs):
+                 linesearch=LineSearchAdaptive(), *args, **kwargs):
         """
         Instantiate gradient solver class.
         Variable attributes (defaults in brackets):
@@ -32,18 +32,15 @@ def __init__(self, beta_type=BetaTypes.HestenesStiefel, orth_value=np.inf,
                 Parameter for Powell's restart strategy. An infinite
                 value disables this strategy. See in code formula for
                 the specific criterion used.
-            - linesearch (None)
-                If None LineSearchAdaptive will be used.
+            - linesearch (LineSearchAdaptive)
+                The linesearch method to use.
         """
         super(ConjugateGradient, self).__init__(*args, **kwargs)
 
         self._beta_type = beta_type
         self._orth_value = orth_value
 
-        if linesearch is None:
-            self._searcher = default_linesearchers.LineSearchAdaptive()
-        else:
-            self._searcher = linesearch
+        self.linesearch = linesearch
 
     def solve(self, problem, x=None):
         """
@@ -71,6 +68,8 @@ def solve(self, problem, x=None):
         objective = problem.cost
         gradient = problem.grad
 
+        linesearch = self.linesearch
+
         # If no starting point is specified, generate one at random.
         if x is None:
             x = man.rand()
@@ -98,7 +97,7 @@ def solve(self, problem, x=None):
         self._start_optlog(extraiterfields=['gradnorm'],
                            solverparams={'beta_type': self._beta_type,
                                          'orth_value': self._orth_value,
-                                         'linesearcher': self._searcher})
+                                         'linesearcher': linesearch})
 
         while True:
             if verbosity >= 2:
@@ -134,8 +133,8 @@ def solve(self, problem, x=None):
             df0 = -gradPgrad
 
             # Execute line search
-            stepsize, newx = self._searcher.search(objective, man, x, desc_dir,
-                                                   cost, df0)
+            stepsize, newx = linesearch.search(objective, man, x, desc_dir,
+                                               cost, df0)
 
             # Compute the new cost-related quantities for newx
             newcost = objective(newx)
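
The practical effect of the change above is that the linesearch is now an
ordinary constructor argument with a ready-made default, rather than a None
sentinel resolved inside __init__, and it is stored on the public
self.linesearch attribute. A minimal usage sketch, with import paths taken
from the diff (the Problem setup is elided):

    from pymanopt.solvers.conjugate_gradient import ConjugateGradient
    from pymanopt.solvers.linesearch import LineSearchAdaptive

    # Default: equivalent to passing LineSearchAdaptive() explicitly.
    solver = ConjugateGradient()

    # Exposed parameter: any object with a compatible .search() method works.
    solver = ConjugateGradient(linesearch=LineSearchAdaptive())
    # solver.solve(problem) then drives every iteration through
    # solver.linesearch.search(...).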

pymanopt/solvers/linesearch.py (3 additions, 3 deletions)
@@ -1,10 +1,10 @@
 import numpy as np
 
 
-class LineSearch(object):
+class LineSearchBackTracking(object):
     """
-    Line-search based on linesearch.m and linesearch_adaptive.m in the manopt
-    MATLAB package.
+    Back-tracking line-search based on linesearch.m in the manopt MATLAB
+    package.
     """
 
     def __init__(self, contraction_factor=.5, optimism=2,
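
After this rename, the solvers depend only on a duck-typed contract: an
object with a search(objective, manifold, x, desc_dir, cost, df0) method
returning (stepsize, newx), as the call sites in the two solver diffs show.
To illustrate the contract, here is a hypothetical fixed-step
implementation; the class name is invented, and it assumes the manifold
exposes a retraction retr(x, u) as pymanopt manifolds do:

    # Hypothetical, for illustration only; not part of this commit.
    class FixedStepLineSearch(object):
        def __init__(self, stepsize=1e-2):
            self._stepsize = stepsize

        def search(self, objective, manifold, x, d, f0, df0):
            # Take a fixed step along the search direction d, then
            # retract back onto the manifold.
            newx = manifold.retr(x, self._stepsize * d)
            return self._stepsize, newx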

pymanopt/solvers/steepest_descent.py (8 additions, 9 deletions)
@@ -2,7 +2,7 @@
 
 import time
 
-from pymanopt.solvers import linesearch as default_linesearchers
+from pymanopt.solvers.linesearch import LineSearchBackTracking
 from pymanopt.solvers.solver import Solver
 
 
@@ -12,13 +12,10 @@ class SteepestDescent(Solver):
     steepestdescent.m from the manopt MATLAB package.
     """
 
-    def __init__(self, linesearch=None, *args, **kwargs):
+    def __init__(self, linesearch=LineSearchBackTracking(), *args, **kwargs):
         super(SteepestDescent, self).__init__(*args, **kwargs)
 
-        if linesearch is None:
-            self._searcher = default_linesearchers.LineSearch()
-        else:
-            self._searcher = linesearch
+        self.linesearch = linesearch
 
     # Function to solve optimisation problem using steepest descent.
     def solve(self, problem, x=None):
@@ -46,6 +43,8 @@ def solve(self, problem, x=None):
         objective = problem.cost
         gradient = problem.grad
 
+        linesearch = self.linesearch
+
        # If no starting point is specified, generate one at random.
        if x is None:
            x = man.rand()
@@ -58,7 +57,7 @@ def solve(self, problem, x=None):
             print(" iter\t\t cost val\t grad. norm")
 
         self._start_optlog(extraiterfields=['gradnorm'],
-                           solverparams={'linesearcher': self._searcher})
+                           solverparams={'linesearcher': linesearch})
 
         while True:
             # Calculate new cost, grad and gradnorm
@@ -77,8 +76,8 @@
             desc_dir = -grad
 
             # Perform line-search
-            stepsize, x = self._searcher.search(objective, man, x, desc_dir,
-                                                cost, -gradnorm**2)
+            stepsize, x = linesearch.search(objective, man, x, desc_dir,
+                                            cost, -gradnorm**2)
 
             stop_reason = self._check_stopping_criterion(
                 time0, stepsize=stepsize, gradnorm=gradnorm, iter=iter)
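
One subtlety of the new signatures: a default such as
linesearch=LineSearchBackTracking() is evaluated once, when the method is
defined, so every solver constructed without the argument shares a single
linesearch instance. A short sketch of that behaviour, using only names
from the diff:

    from pymanopt.solvers.linesearch import LineSearchBackTracking
    from pymanopt.solvers.steepest_descent import SteepestDescent

    # Python evaluates default arguments once, so the default is shared.
    a = SteepestDescent()
    b = SteepestDescent()
    assert a.linesearch is b.linesearch

    # Passing a fresh instance gives each solver its own (stateful) object.
    c = SteepestDescent(linesearch=LineSearchBackTracking())
    assert c.linesearch is not a.linesearch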
