Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add minimization method to Utils module
- Loading branch information
Showing 5 changed files with 176 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,5 @@ void Init_rumale(void) | |
mRumale = rb_define_module("Rumale"); | ||
|
||
init_tree_module(); | ||
init_utils_module(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,5 +4,6 @@ | |
#include <ruby.h> | ||
|
||
#include "tree.h" | ||
#include "utils.h" | ||
|
||
#endif /* RUMALE_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#include "utils.h" | ||
|
||
RUBY_EXTERN VALUE mRumale; | ||
|
||
#define SIGMA_INIT 1e-4 | ||
#define BETA_MIN 1e-15 | ||
#define BETA_MAX 1e+15 | ||
|
||
/** | ||
* @!visibility private | ||
* The multivariate optimization with the scaled conjugate gradient method. | ||
* | ||
* *Reference* | ||
* 1. Moller, M F., "A Scaled Conjugate Gradient Algorithm for Fast Supervised Learning," Neural Networks, Vol. 6, pp. 525--533, 1993. | ||
* | ||
* @overload minimize(fnc, x, args, max_iter, xtol, ftol, jtol, logger) | ||
* @param fnc [Method/Proc] The method for calculating the objective function to be minimized and its gradient vector. | ||
* @param x [Numo::DFloat] (shape: n_dimensions) The initial points. | ||
* @param args [Array] The arguments pass to the 'fnc'. | ||
* @param max_iter [Integer] The maximum number of iterations. | ||
* @param xtol [Float] The tolerance for termination for the optimal vector norm. | ||
* @param ftol [Float] The tolerance for termination for the objective function value. | ||
* @param jtol [Float] The tolerance for termination for the gradient norm. | ||
* @param logger [Method/Proc/Nil] The method for logging. | ||
* | ||
* @return [Hash] { x:, n_iter: } The minimization resuts. | ||
*/ | ||
static VALUE | ||
minimize(VALUE self, VALUE fnc, VALUE x, VALUE args, VALUE _max_iter, VALUE _xtol, VALUE _ftol, VALUE _jtol, VALUE logger) | ||
{ | ||
const int max_iter = NUM2INT(_max_iter); | ||
const double xtol = NUM2DBL(_xtol); | ||
const double ftol = NUM2DBL(_ftol); | ||
const double jtol = NUM2DBL(_jtol); | ||
const int n_dimensions = NUM2INT(rb_funcall(x, rb_intern("size"), 0)); | ||
int n_iter = 0; | ||
char success = 1; | ||
int n_successes = 0; | ||
VALUE ret = rb_funcall(fnc, rb_intern("call"), 2, x, args); | ||
double f_prev = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc")))); | ||
double f_next = 0.0; | ||
double f_curr = f_prev; | ||
VALUE j_prev = rb_hash_aref(ret, ID2SYM(rb_intern("jcb"))); | ||
VALUE j_next = j_prev; | ||
VALUE d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1)); | ||
double alpha = 0.0; | ||
double beta = 1.0; | ||
double gamma = 0.0; | ||
double delta = 0.0; | ||
double ddelta = 0.0; | ||
double theta = 0.0; | ||
double kappa = 0.0; | ||
double mu = 0.0; | ||
double sigma = 0.0; | ||
VALUE x_next = Qnil; | ||
VALUE x_plus = Qnil; | ||
VALUE j_plus = Qnil; | ||
VALUE result = rb_hash_new(); | ||
|
||
for (n_iter = 0; n_iter < max_iter; n_iter++) { | ||
if (success == 1) { | ||
mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next)); | ||
if (mu >= 0.0) { | ||
d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1)); | ||
mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next)); | ||
} | ||
kappa = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, d)); | ||
if (kappa < 1e-16) break; | ||
|
||
sigma = SIGMA_INIT / sqrt(kappa); | ||
x_plus = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(sigma))); | ||
ret = rb_funcall(fnc, rb_intern("call"), 2, x_plus, args); | ||
j_plus = rb_hash_aref(ret, ID2SYM(rb_intern("jcb"))); | ||
theta = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, rb_funcall(j_plus, rb_intern("-"), 1, j_next))) / sigma; | ||
} | ||
|
||
delta = theta + beta * kappa; | ||
if (delta <= 0) { | ||
delta = beta * kappa; | ||
beta -= theta / kappa; | ||
} | ||
alpha = -mu / delta; | ||
|
||
x_next = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha))); | ||
ret = rb_funcall(fnc, rb_intern("call"), 2, x_next, args); | ||
f_next = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc")))); | ||
|
||
if (isinf(f_next) || isnan(f_next)) break; | ||
|
||
ddelta = 2 * (f_next - f_prev) / (alpha * mu); | ||
if (ddelta >= 0) { | ||
success = 1; | ||
n_successes++; | ||
x = x_next; | ||
f_curr = f_next; | ||
} else { | ||
success = 0; | ||
f_curr = f_prev; | ||
} | ||
|
||
if (!NIL_P(logger)) rb_funcall(logger, rb_intern("call"), 2, INT2NUM(n_iter + 1), DBL2NUM(f_curr)); | ||
|
||
if (success == 1) { | ||
if (fabs(f_next - f_prev) < ftol) break; | ||
if (NUM2DBL(rb_funcall(rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha)), rb_intern("abs"), 0), rb_intern("max"), 0)) < xtol) break; | ||
|
||
f_prev = f_next; | ||
j_prev = j_next; | ||
j_next = rb_hash_aref(ret, ID2SYM(rb_intern("jcb"))); | ||
|
||
if (NUM2DBL(rb_funcall(j_next, rb_intern("dot"), 1, j_next)) < jtol) break; | ||
} | ||
|
||
if (ddelta < 0.25) beta = beta * 4 < BETA_MAX ? beta * 4 : BETA_MAX; | ||
if (ddelta > 0.75) beta = beta / 4 > BETA_MIN ? beta / 4 : BETA_MIN; | ||
|
||
if (n_successes == n_dimensions) { | ||
d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1)); | ||
n_successes = 0; | ||
} else if (success == 1) { | ||
gamma = NUM2DBL(rb_funcall(rb_funcall(j_prev, rb_intern("-"), 1, j_next), rb_intern("dot"), 1, j_next)) / mu; | ||
d = rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(gamma)), rb_intern("-"), 1, j_next); | ||
} | ||
} | ||
|
||
rb_hash_aset(result, ID2SYM(rb_intern("x")), x); | ||
rb_hash_aset(result, ID2SYM(rb_intern("n_iter")), INT2NUM(n_iter)); | ||
|
||
return result; | ||
} | ||
|
||
void init_utils_module() | ||
{ | ||
VALUE mUtils = rb_define_module_under(mRumale, "Utils"); | ||
|
||
rb_define_module_function(mUtils, "minimize", minimize, 8); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#ifndef RUMALE_UTILS_H | ||
#define RUMALE_UTILS_H 1 | ||
|
||
#include <math.h> | ||
#include <ruby.h> | ||
#include <numo/narray.h> | ||
#include <numo/template.h> | ||
|
||
void init_utils_module(); | ||
|
||
#endif /* RUMALE_UTILS_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# frozen_string_literal: true

RSpec.describe Rumale::Utils do
  describe '#minimize' do
    # Starting point, feature matrix, and target expectations for a
    # maximum-entropy style objective.
    let(:x) { Numo::DFloat.zeros(3) }
    let(:f) { Numo::DFloat[[1, 1, 1], [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 0]] }
    let(:k) { Numo::DFloat[1.0, 0.3, 0.5] }

    # Objective callback: returns the function value (:fnc) and its
    # gradient (:jcb) at a given point.
    let(:fnc) do
      proc do |point, extra|
        feats, targets = extra
        log_scores = feats.dot(point)
        log_partition = Math.log(Numo::NMath.exp(log_scores).sum)
        probs = Numo::NMath.exp(log_scores - log_partition)
        { fnc: log_partition - targets.dot(point),
          jcb: feats.transpose.dot(probs) - targets }
      end
    end

    let(:res) { described_class.minimize(fnc, x, [f, k], 100, 1e-8, 1e-8, 1e-8, nil) }
    # Gap between the objective at the found minimizer and at a known-good point.
    let(:err) { (fnc.call(res[:x], [f, k])[:fnc] - fnc.call(Numo::DFloat[0, -0.52, 0.48], [f, k])[:fnc]).abs }

    it { expect(res.keys).to match(%i[x n_iter]) }
    it { expect(err).to be < 1.0e-5 }
  end
end