From c3f908ede7a5756c1d9ed3da06fe1d09f12c70c6 Mon Sep 17 00:00:00 2001 From: Wu Lin Date: Tue, 24 Jun 2014 23:26:53 -0400 Subject: [PATCH] update the lbfgs optimizer --- src/shogun/optimization/lbfgs/lbfgs.cpp | 28 +++++++++++++++------- src/shogun/optimization/lbfgs/lbfgs.h | 31 ++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/shogun/optimization/lbfgs/lbfgs.cpp b/src/shogun/optimization/lbfgs/lbfgs.cpp index 6f2a79da0fb..3af3c11e936 100644 --- a/src/shogun/optimization/lbfgs/lbfgs.cpp +++ b/src/shogun/optimization/lbfgs/lbfgs.cpp @@ -83,6 +83,7 @@ struct tag_callback_data { void *instance; lbfgs_evaluate_t proc_evaluate; lbfgs_progress_t proc_progress; + lbfgs_adjust_step_t proc_adjust_step; }; typedef struct tag_callback_data callback_data_t; @@ -210,7 +211,8 @@ int32_t lbfgs( lbfgs_evaluate_t proc_evaluate, lbfgs_progress_t proc_progress, void *instance, - lbfgs_parameter_t *_param + lbfgs_parameter_t *_param, + lbfgs_adjust_step_t proc_adjust_step ) { int32_t ret; @@ -237,6 +239,7 @@ int32_t lbfgs( cd.instance = instance; cd.proc_evaluate = proc_evaluate; cd.proc_progress = proc_progress; + cd.proc_adjust_step=proc_adjust_step; /* Check the input parameters for errors. */ if (n <= 0) { @@ -631,8 +634,11 @@ static int32_t line_search_backtracking( dgtest = param->ftol * dginit; for (;;) { - std::copy(xp,xp+n,x); - SGVector::add(x, 1, x, *stp, s, n); + std::copy(xp,xp+n,x); + if (cd->proc_adjust_step) + *stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp); + + SGVector::add(x, 1, x, *stp, s, n); /* Evaluate the function and gradient values. */ *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp); @@ -717,8 +723,10 @@ static int32_t line_search_backtracking_owlqn( for (;;) { /* Update the current point. 
*/ - std::copy(xp,xp+n,x); - SGVector::add(x, 1, x, *stp, s, n); + std::copy(xp,xp+n,x); + if (cd->proc_adjust_step) + *stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp); + SGVector::add(x, 1, x, *stp, s, n); /* The current point is projected onto the orthant. */ owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end); @@ -791,7 +799,7 @@ static int32_t line_search_morethuente( } /* Compute the initial gradient in the search direction. */ - dginit = SGVector::dot(g, s, n); + dginit = SGVector::dot(g, s, n); /* Make sure that s points to a descent direction. */ if (0 < dginit) { @@ -848,12 +856,14 @@ static int32_t line_search_morethuente( Compute the current value of x: x <- x + (*stp) * s. */ - std::copy(xp,xp+n,x); - SGVector::add(x, 1, x, *stp, s, n); + std::copy(xp,xp+n,x); + if (cd->proc_adjust_step) + *stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp); + SGVector::add(x, 1, x, *stp, s, n); /* Evaluate the function and gradient values. */ *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp); - dg = SGVector::dot(g, s, n); + dg = SGVector::dot(g, s, n); ftest1 = finit + *stp * dgtest; ++count; diff --git a/src/shogun/optimization/lbfgs/lbfgs.h b/src/shogun/optimization/lbfgs/lbfgs.h index fb40af1b845..f779dd484a5 100644 --- a/src/shogun/optimization/lbfgs/lbfgs.h +++ b/src/shogun/optimization/lbfgs/lbfgs.h @@ -391,6 +391,31 @@ typedef int (*lbfgs_progress_t)( int ls ); +/** + * Callback interface to adjust step size based on constraints. + * + * If the function pointer is not NULL, the lbfgs() function calls + * this function to adjust its step size. A client program can implement this function + * to adjust the step size used in lbfgs update based on user-defined constraints. + * Note that the update is x_new = x + step * d, where step is positive. + * + * @param instance The user data sent for lbfgs() function by the client. + * @param x The current values of variables. + * @param d The direction vector of variables.
+ * @param n The number of variables. + * @param step The current step of the line search routine. + * + * @retval float64_t The value of adjusted step size + */ + +typedef float64_t (*lbfgs_adjust_step_t)( + void *instance, + const float64_t *x, + const float64_t *d, + const int n, + const float64_t step + ); + /* A user must implement a function compatible with ::lbfgs_evaluate_t (evaluation callback) and pass the pointer to the callback function to lbfgs() arguments. @@ -445,6 +470,9 @@ In this formula, ||.|| denotes the Euclidean norm. * parameter to \c NULL to use the default parameters. * Call lbfgs_parameter_init() function to fill a * structure with the default values. + * @param proc_adjust_step The callback function to adjust step size based on constraints. + * This argument can be set to \c NULL if there is no constraint. + * * @retval int The status code. This function returns zero if the * minimization process terminates without an error. A * non-zero value indicates an error. @@ -456,7 +484,8 @@ int lbfgs( lbfgs_evaluate_t proc_evaluate, lbfgs_progress_t proc_progress, void *instance, - lbfgs_parameter_t *param + lbfgs_parameter_t *param, + lbfgs_adjust_step_t proc_adjust_step=NULL ); /**