Skip to content

Commit

Permalink
✨ Add minimization method to Utils module
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Sep 22, 2020
1 parent 3a9a29d commit 1d429fd
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 0 deletions.
1 change: 1 addition & 0 deletions ext/rumale/rumale.c
Expand Up @@ -7,4 +7,5 @@ void Init_rumale(void)
mRumale = rb_define_module("Rumale");

init_tree_module();
init_utils_module();
}
1 change: 1 addition & 0 deletions ext/rumale/rumale.h
Expand Up @@ -4,5 +4,6 @@
#include <ruby.h>

#include "tree.h"
#include "utils.h"

#endif /* RUMALE_H */
137 changes: 137 additions & 0 deletions ext/rumale/utils.c
@@ -0,0 +1,137 @@
#include "utils.h"

/* The Rumale module object. It is only declared here; Init_rumale assigns it. */
RUBY_EXTERN VALUE mRumale;

/* Numerator of the finite-difference step used by minimize:
   sigma = SIGMA_INIT / sqrt(d.dot(d)), where d is the search direction. */
#define SIGMA_INIT 1e-4
/* Bounds applied to the scale parameter beta in minimize when it is divided
   by 4 (lower bound) or multiplied by 4 (upper bound). */
#define BETA_MIN 1e-15
#define BETA_MAX 1e+15

/* Calls fnc.call(x, args) and returns the result after checking its type.
   The result is later read with rb_hash_aref, which must only be given a
   Hash, so anything else is rejected here with a TypeError. */
static VALUE
scg_evaluate(VALUE fnc, VALUE x, VALUE args)
{
  VALUE ret = rb_funcall(fnc, rb_intern("call"), 2, x, args);

  Check_Type(ret, T_HASH);
  return ret;
}

/* Returns hash[:key], i.e. the entry stored under the symbol named by key. */
static VALUE
scg_fetch(VALUE hash, const char* key)
{
  return rb_hash_aref(hash, ID2SYM(rb_intern(key)));
}

/* Returns a.dot(b) converted to a C double. */
static double
scg_dot(VALUE a, VALUE b)
{
  return NUM2DBL(rb_funcall(a, rb_intern("dot"), 1, b));
}

/* Returns v * c. */
static VALUE
scg_scale(VALUE v, double c)
{
  return rb_funcall(v, rb_intern("*"), 1, DBL2NUM(c));
}

/* Returns x + d * c, the point reached from x by a step of length factor c along d. */
static VALUE
scg_step(VALUE x, VALUE d, double c)
{
  return rb_funcall(x, rb_intern("+"), 1, scg_scale(d, c));
}

/**
 * @!visibility private
 * The multivariate optimization with the scaled conjugate gradient method.
 *
 * *Reference*
 * 1. Moller, M F., "A Scaled Conjugate Gradient Algorithm for Fast Supervised Learning," Neural Networks, Vol. 6, pp. 525--533, 1993.
 *
 * @overload minimize(fnc, x, args, max_iter, xtol, ftol, jtol, logger)
 * @param fnc [Method/Proc] The method for calculating the objective function to be minimized and its gradient vector.
 *   It is invoked as fnc.call(x, args) and must return a Hash holding the function value under :fnc
 *   and the gradient vector under :jcb.
 * @param x [Numo::DFloat] (shape: n_dimensions) The initial point.
 * @param args [Array] The arguments passed to the 'fnc'.
 * @param max_iter [Integer] The maximum number of iterations.
 * @param xtol [Float] The tolerance for termination, compared with the largest absolute component of the accepted step.
 * @param ftol [Float] The tolerance for termination, compared with the absolute change of the objective function value.
 * @param jtol [Float] The tolerance for termination, compared with the squared norm (j.dot(j)) of the gradient.
 * @param logger [Method/Proc/Nil] The method for logging. When not nil, it is called once per iteration
 *   with the one-based iteration number and the current objective function value.
 *
 * @raise [TypeError] If 'fnc' returns something other than a Hash.
 * @return [Hash] { x:, n_iter: } The minimization results. x is the last accepted point and n_iter is
 *   the value of the loop counter when the loop stopped: the zero-based index of the iteration in which
 *   a termination test fired, or max_iter when none did.
 */
static VALUE
minimize(VALUE self, VALUE fnc, VALUE x, VALUE args, VALUE _max_iter, VALUE _xtol, VALUE _ftol, VALUE _jtol, VALUE logger)
{
  const int max_iter = NUM2INT(_max_iter);
  const double xtol = NUM2DBL(_xtol);
  const double ftol = NUM2DBL(_ftol);
  const double jtol = NUM2DBL(_jtol);
  const int n_dimensions = NUM2INT(rb_funcall(x, rb_intern("size"), 0));
  int n_iter = 0;
  int success = 1; /* 1 forces the directional derivatives to be computed on the first pass */
  int n_successes = 0;
  VALUE ret = scg_evaluate(fnc, x, args);
  double f_prev = NUM2DBL(scg_fetch(ret, "fnc"));
  double f_next = 0.0;
  double f_curr = f_prev;
  VALUE j_prev = scg_fetch(ret, "jcb");
  VALUE j_next = j_prev;
  VALUE d = scg_scale(j_next, -1.0); /* start along the negative gradient */
  double alpha = 0.0;
  double beta = 1.0;
  double gamma = 0.0;
  double delta = 0.0;
  double ddelta = 0.0;
  double theta = 0.0;
  double kappa = 0.0;
  double mu = 0.0;
  double sigma = 0.0;
  VALUE x_next = Qnil;
  VALUE x_plus = Qnil;
  VALUE j_plus = Qnil;
  VALUE step = Qnil;
  VALUE result = Qnil;

  for (n_iter = 0; n_iter < max_iter; n_iter++) {
    /* mu, kappa and theta are refreshed only after an accepted step; a rejected
       step retries from the same point with a larger beta. */
    if (success) {
      mu = scg_dot(d, j_next);
      if (mu >= 0.0) {
        /* d is not a descent direction: fall back to the negative gradient. */
        d = scg_scale(j_next, -1.0);
        mu = scg_dot(d, j_next);
      }
      kappa = scg_dot(d, d);
      if (kappa < 1e-16) break;

      /* Finite-difference estimate of the curvature along d. */
      sigma = SIGMA_INIT / sqrt(kappa);
      x_plus = scg_step(x, d, sigma);
      ret = scg_evaluate(fnc, x_plus, args);
      j_plus = scg_fetch(ret, "jcb");
      theta = scg_dot(d, rb_funcall(j_plus, rb_intern("-"), 1, j_next)) / sigma;
    }

    delta = theta + beta * kappa;
    if (delta <= 0) {
      delta = beta * kappa;
      beta -= theta / kappa;
    }
    alpha = -mu / delta;

    x_next = scg_step(x, d, alpha);
    ret = scg_evaluate(fnc, x_next, args);
    f_next = NUM2DBL(scg_fetch(ret, "fnc"));

    /* Stop on an infinite or NaN objective value. */
    if (!isfinite(f_next)) break;

    /* Comparison ratio: a non-negative value accepts the step. */
    ddelta = 2 * (f_next - f_prev) / (alpha * mu);
    if (ddelta >= 0) {
      success = 1;
      n_successes++;
      x = x_next;
      f_curr = f_next;
    } else {
      success = 0;
      f_curr = f_prev;
    }

    if (!NIL_P(logger)) rb_funcall(logger, rb_intern("call"), 2, INT2NUM(n_iter + 1), DBL2NUM(f_curr));

    if (success) {
      if (fabs(f_next - f_prev) < ftol) break;

      step = rb_funcall(scg_scale(d, alpha), rb_intern("abs"), 0);
      if (NUM2DBL(rb_funcall(step, rb_intern("max"), 0)) < xtol) break;

      f_prev = f_next;
      j_prev = j_next;
      /* ret still holds the evaluation at x_next, which is now x. */
      j_next = scg_fetch(ret, "jcb");

      if (scg_dot(j_next, j_next) < jtol) break;
    }

    if (ddelta < 0.25) beta = beta * 4 < BETA_MAX ? beta * 4 : BETA_MAX;
    if (ddelta > 0.75) beta = beta / 4 > BETA_MIN ? beta / 4 : BETA_MIN;

    if (n_successes == n_dimensions) {
      /* Restart along the negative gradient after n_dimensions accepted steps. */
      d = scg_scale(j_next, -1.0);
      n_successes = 0;
    } else if (success) {
      gamma = scg_dot(rb_funcall(j_prev, rb_intern("-"), 1, j_next), j_next) / mu;
      d = rb_funcall(scg_scale(d, gamma), rb_intern("-"), 1, j_next);
    }
  }

  result = rb_hash_new();
  rb_hash_aset(result, ID2SYM(rb_intern("x")), x);
  rb_hash_aset(result, ID2SYM(rb_intern("n_iter")), INT2NUM(n_iter));

  return result;
}

/* Defines the Utils module under Rumale and registers minimize on it as a
   module function that takes eight arguments. */
void init_utils_module(void)
{
  VALUE mUtils = rb_define_module_under(mRumale, "Utils");

  rb_define_module_function(mUtils, "minimize", minimize, 8);
}
11 changes: 11 additions & 0 deletions ext/rumale/utils.h
@@ -0,0 +1,11 @@
#ifndef RUMALE_UTILS_H
#define RUMALE_UTILS_H 1

#include <math.h>
#include <ruby.h>
#include <numo/narray.h>
#include <numo/template.h>

/* Defines the Rumale::Utils module and its module functions.
   Declared with (void) so that the declaration is a full prototype. */
void init_utils_module(void);

#endif /* RUMALE_UTILS_H */
26 changes: 26 additions & 0 deletions spec/rumale/utils_spec.rb
@@ -0,0 +1,26 @@
# frozen_string_literal: true

RSpec.describe Rumale::Utils do
  describe '#minimize' do
    # Starting point handed to the optimizer.
    let(:x) { Numo::DFloat.zeros(3) }
    # Fixed data forwarded to the objective through the args array.
    let(:f) { Numo::DFloat[[1, 1, 1], [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 0]] }
    let(:k) { Numo::DFloat[1.0, 0.3, 0.5] }

    # Objective in the form minimize expects: called with (point, args) and
    # returning a Hash with the function value (:fnc) and the gradient (:jcb).
    let(:fnc) do
      proc do |point, args|
        features, targets = args
        scores = features.dot(point)
        log_norm = Math.log(Numo::NMath.exp(scores).sum)
        probs = Numo::NMath.exp(scores - log_norm)
        value = log_norm - targets.dot(point)
        gradient = features.transpose.dot(probs) - targets
        { fnc: value, jcb: gradient }
      end
    end

    let(:res) { described_class.minimize(fnc, x, [f, k], 100, 1e-8, 1e-8, 1e-8, nil) }

    # Absolute gap between the objective at the returned point and at a
    # hard-coded reference point.
    let(:err) do
      params = [f, k]
      found = fnc.call(res[:x], params)[:fnc]
      reference = fnc.call(Numo::DFloat[0, -0.52, 0.48], params)[:fnc]
      (found - reference).abs
    end

    it { expect(res.keys).to match(%i[x n_iter]) }
    it { expect(err).to be < 1.0e-5 }
  end
end

0 comments on commit 1d429fd

Please sign in to comment.