Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update the lbfgs optimizer #2337

Merged
merged 3 commits into from Jun 28, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/shogun/mathematics/Math.cpp
Expand Up @@ -364,10 +364,10 @@ bool CMath::strtold(const char* str, floatmax_t* long_double_result)
float64_t CMath::get_abs_tolorance(float64_t true_value, float64_t rel_tolorance)
{
REQUIRE(rel_tolorance > 0 && rel_tolorance < 1.0,
"Relative tolorance should be less than 1.0 and positive");
"Relative tolorance (%f) should be less than 1.0 and positive\n", rel_tolorance);
REQUIRE(is_finite(true_value),
"The true_value should be finite");
float64_t abs_tolorance = 0.0;
"The true_value should be finite\n");
float64_t abs_tolorance = rel_tolorance;
if (abs(true_value)>0.0)
{
if (log(abs(true_value)) + log(rel_tolorance) < log(F_MIN_VAL64))
Expand Down
28 changes: 19 additions & 9 deletions src/shogun/optimization/lbfgs/lbfgs.cpp
Expand Up @@ -83,6 +83,7 @@ struct tag_callback_data {
void *instance;
lbfgs_evaluate_t proc_evaluate;
lbfgs_progress_t proc_progress;
lbfgs_adjust_step_t proc_adjust_step;
};
typedef struct tag_callback_data callback_data_t;

Expand Down Expand Up @@ -210,7 +211,8 @@ int32_t lbfgs(
lbfgs_evaluate_t proc_evaluate,
lbfgs_progress_t proc_progress,
void *instance,
lbfgs_parameter_t *_param
lbfgs_parameter_t *_param,
lbfgs_adjust_step_t proc_adjust_step
)
{
int32_t ret;
Expand All @@ -237,6 +239,7 @@ int32_t lbfgs(
cd.instance = instance;
cd.proc_evaluate = proc_evaluate;
cd.proc_progress = proc_progress;
cd.proc_adjust_step=proc_adjust_step;

/* Check the input parameters for errors. */
if (n <= 0) {
Expand Down Expand Up @@ -631,8 +634,11 @@ static int32_t line_search_backtracking(
dgtest = param->ftol * dginit;

for (;;) {
std::copy(xp,xp+n,x);
SGVector<float64_t>::add(x, 1, x, *stp, s, n);
std::copy(xp,xp+n,x);
if (cd->proc_adjust_step)
*stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp);

SGVector<float64_t>::add(x, 1, x, *stp, s, n);

/* Evaluate the function and gradient values. */
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
Expand Down Expand Up @@ -717,8 +723,10 @@ static int32_t line_search_backtracking_owlqn(

for (;;) {
/* Update the current point. */
std::copy(xp,xp+n,x);
SGVector<float64_t>::add(x, 1, x, *stp, s, n);
std::copy(xp,xp+n,x);
if (cd->proc_adjust_step)
*stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp);
SGVector<float64_t>::add(x, 1, x, *stp, s, n);

/* The current point is projected onto the orthant. */
owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
Expand Down Expand Up @@ -791,7 +799,7 @@ static int32_t line_search_morethuente(
}

/* Compute the initial gradient in the search direction. */
dginit = SGVector<float64_t>::dot(g, s, n);
dginit = SGVector<float64_t>::dot(g, s, n);

/* Make sure that s points to a descent direction. */
if (0 < dginit) {
Expand Down Expand Up @@ -848,12 +856,14 @@ static int32_t line_search_morethuente(
Compute the current value of x:
x <- x + (*stp) * s.
*/
std::copy(xp,xp+n,x);
SGVector<float64_t>::add(x, 1, x, *stp, s, n);
std::copy(xp,xp+n,x);
if (cd->proc_adjust_step)
*stp=cd->proc_adjust_step(cd->instance, x, s, cd->n, *stp);
SGVector<float64_t>::add(x, 1, x, *stp, s, n);

/* Evaluate the function and gradient values. */
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
dg = SGVector<float64_t>::dot(g, s, n);
dg = SGVector<float64_t>::dot(g, s, n);

ftest1 = finit + *stp * dgtest;
++count;
Expand Down
31 changes: 30 additions & 1 deletion src/shogun/optimization/lbfgs/lbfgs.h
Expand Up @@ -391,6 +391,31 @@ typedef int (*lbfgs_progress_t)(
int ls
);

/**
* Callback interface to adjust step size based on constraints.
*
* If the function pointer is not NULL, the lbfgs() function calls
* this function to adjust its step size. A client program can implement this function
* to adjust the step size used in lbfgs update based on user-defined constraints.
* Note that the update is x_new = x + step * d, where step is positive.
*
* @param instance The user data sent for lbfgs() function by the client.
* @param x The current values of variables.
* @param d The direction vector of variables.
* @param n The number of variables.
* @param step The current step of the line search routine.
*
* @retval float64_t The value of adjusted step size
*/

typedef float64_t (*lbfgs_adjust_step_t)(
void *instance,
const float64_t *x,
const float64_t *d,
const int n,
const float64_t step
);

/*
A user must implement a function compatible with ::lbfgs_evaluate_t (evaluation
callback) and pass the pointer to the callback function to lbfgs() arguments.
Expand Down Expand Up @@ -445,6 +470,9 @@ In this formula, ||.|| denotes the Euclidean norm.
* parameter to \c NULL to use the default parameters.
* Call lbfgs_parameter_init() function to fill a
* structure with the default values.
* @param proc_adjust_step The callback function to adjust step size based on constraints.
* This argument can be set to \c NULL if there is not constraint.
*
* @retval int The status code. This function returns zero if the
* minimization process terminates without an error. A
* non-zero value indicates an error.
Expand All @@ -456,7 +484,8 @@ int lbfgs(
lbfgs_evaluate_t proc_evaluate,
lbfgs_progress_t proc_progress,
void *instance,
lbfgs_parameter_t *param
lbfgs_parameter_t *param,
lbfgs_adjust_step_t proc_adjust_step=NULL
);

/**
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/mathematics/Math_unittest.cc
Expand Up @@ -367,7 +367,7 @@ TEST(CMath, fequals_close_to_zero)

TEST(CMath, get_abs_tolorance)
{
EXPECT_EQ(CMath::get_abs_tolorance(0.0, 0.01), 0.0);
EXPECT_EQ(CMath::get_abs_tolorance(0.0, 0.01), 0.01);
EXPECT_NEAR(CMath::get_abs_tolorance(-0.01, 0.01), 0.0001, 1E-15);
EXPECT_NEAR(CMath::get_abs_tolorance(-9.5367431640625e-7, 0.01), 9.5367431640625e-9, 1E-15);
EXPECT_NEAR(CMath::get_abs_tolorance(9.5367431640625e-7, 0.01), 9.5367431640625e-9, 1E-15);
Expand Down