Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
108 lines (87 sloc) 3.8 KB
from __future__ import print_function, division
import time
from copy import deepcopy
from pymanopt.solvers.linesearch import LineSearchBackTracking
from pymanopt.solvers.solver import Solver
class SteepestDescent(Solver):
"""
Steepest descent (gradient descent) algorithm based on
steepestdescent.m from the manopt MATLAB package.
"""
def __init__(self, linesearch=LineSearchBackTracking(), *args, **kwargs):
super(SteepestDescent, self).__init__(*args, **kwargs)
if linesearch is None:
self._linesearch = LineSearchBackTracking()
else:
self._linesearch = linesearch
self.linesearch = None
# Function to solve optimisation problem using steepest descent.
def solve(self, problem, x=None, reuselinesearch=False):
"""
Perform optimization using gradient descent with linesearch.
This method first computes the gradient (derivative) of obj
w.r.t. arg, and then optimizes by moving in the direction of
steepest descent (which is the opposite direction to the gradient).
Arguments:
- problem
Pymanopt problem setup using the Problem class, this must
have a .manifold attribute specifying the manifold to optimize
over, as well as a cost and enough information to compute
the gradient of that cost.
- x=None
Optional parameter. Starting point on the manifold. If none
then a starting point will be randomly generated.
- reuselinesearch=False
Whether to reuse the previous linesearch object. Allows to
use information from a previous solve run.
Returns:
- x
Local minimum of obj, or if algorithm terminated before
convergence x will be the point at which it terminated.
"""
man = problem.manifold
verbosity = problem.verbosity
objective = problem.cost
gradient = problem.grad
if not reuselinesearch or self.linesearch is None:
self.linesearch = deepcopy(self._linesearch)
linesearch = self.linesearch
# If no starting point is specified, generate one at random.
if x is None:
x = man.rand()
# Initialize iteration counter and timer
iter = 0
time0 = time.time()
if verbosity >= 2:
print(" iter\t\t cost val\t grad. norm")
self._start_optlog(extraiterfields=['gradnorm'],
solverparams={'linesearcher': linesearch})
while True:
# Calculate new cost, grad and gradnorm
cost = objective(x)
grad = gradient(x)
gradnorm = man.norm(x, grad)
iter = iter + 1
if verbosity >= 2:
print("%5d\t%+.16e\t%.8e" % (iter, cost, gradnorm))
if self._logverbosity >= 2:
self._append_optlog(iter, x, cost, gradnorm=gradnorm)
# Descent direction is minus the gradient
desc_dir = -grad
# Perform line-search
stepsize, x = linesearch.search(objective, man, x, desc_dir,
cost, -gradnorm**2)
stop_reason = self._check_stopping_criterion(
time0, stepsize=stepsize, gradnorm=gradnorm, iter=iter)
if stop_reason:
if verbosity >= 1:
print(stop_reason)
print('')
break
if self._logverbosity <= 0:
return x
else:
self._stop_optlog(x, objective(x), stop_reason, time0,
stepsize=stepsize, gradnorm=gradnorm,
iter=iter)
return x, self._optlog