From 1d429fdac5b05a3ea8eb61345feacd0f05768342 Mon Sep 17 00:00:00 2001
From: yoshoku
Date: Wed, 9 Sep 2020 01:03:59 +0900
Subject: [PATCH] :sparkles: Add minimization method to Utils module

---
 ext/rumale/rumale.c       |   1 +
 ext/rumale/rumale.h       |   1 +
 ext/rumale/utils.c        | 137 ++++++++++++++++++++++++++++++++++++++
 ext/rumale/utils.h        |  11 +++
 spec/rumale/utils_spec.rb |  26 ++++++++
 5 files changed, 176 insertions(+)
 create mode 100644 ext/rumale/utils.c
 create mode 100644 ext/rumale/utils.h
 create mode 100644 spec/rumale/utils_spec.rb

diff --git a/ext/rumale/rumale.c b/ext/rumale/rumale.c
index 07810417..f0dd8aa6 100644
--- a/ext/rumale/rumale.c
+++ b/ext/rumale/rumale.c
@@ -7,4 +7,5 @@ void Init_rumale(void)
   mRumale = rb_define_module("Rumale");
 
   init_tree_module();
+  init_utils_module();
 }
diff --git a/ext/rumale/rumale.h b/ext/rumale/rumale.h
index ec09cf29..f9536846 100644
--- a/ext/rumale/rumale.h
+++ b/ext/rumale/rumale.h
@@ -4,5 +4,6 @@
 #include <ruby.h>
 
 #include "tree.h"
+#include "utils.h"
 
 #endif /* RUMALE_H */
diff --git a/ext/rumale/utils.c b/ext/rumale/utils.c
new file mode 100644
index 00000000..df092a0c
--- /dev/null
+++ b/ext/rumale/utils.c
@@ -0,0 +1,137 @@
+#include "utils.h"
+
+RUBY_EXTERN VALUE mRumale;
+
+#define SIGMA_INIT 1e-4
+#define BETA_MIN 1e-15
+#define BETA_MAX 1e+15
+
+/**
+ * @!visibility private
+ * Multivariate optimization with the scaled conjugate gradient method.
+ *
+ * *Reference*
+ * 1. Moller, M. F., "A Scaled Conjugate Gradient Algorithm for Fast Supervised Learning," Neural Networks, Vol. 6, pp. 525--533, 1993.
+ *
+ * @overload minimize(fnc, x, args, max_iter, xtol, ftol, jtol, logger)
+ *   @param fnc [Method/Proc] The method that calculates the objective function to be minimized and its gradient vector.
+ *   @param x [Numo::DFloat] (shape: n_dimensions) The initial point.
+ *   @param args [Array] The arguments passed to 'fnc'.
+ *   @param max_iter [Integer] The maximum number of iterations.
+ *   @param xtol [Float] The tolerance for termination in terms of the update of the solution vector.
+ *   @param ftol [Float] The tolerance for termination in terms of the change in the objective function value.
+ *   @param jtol [Float] The tolerance for termination in terms of the gradient norm.
+ *   @param logger [Method/Proc/Nil] The method for logging, called with the iteration number and the current objective value; nil disables logging.
+ *
+ * @return [Hash] { x:, n_iter: } The minimization results.
+ */
+static VALUE
+minimize(VALUE self, VALUE fnc, VALUE x, VALUE args, VALUE _max_iter, VALUE _xtol, VALUE _ftol, VALUE _jtol, VALUE logger)
+{
+  const int max_iter = NUM2INT(_max_iter);
+  const double xtol = NUM2DBL(_xtol);
+  const double ftol = NUM2DBL(_ftol);
+  const double jtol = NUM2DBL(_jtol);
+  const int n_dimensions = NUM2INT(rb_funcall(x, rb_intern("size"), 0));
+  int n_iter = 0;
+  char success = 1;
+  int n_successes = 0;
+  VALUE ret = rb_funcall(fnc, rb_intern("call"), 2, x, args);
+  double f_prev = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc"))));
+  double f_next = 0.0;
+  double f_curr = f_prev;
+  VALUE j_prev = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
+  VALUE j_next = j_prev;
+  VALUE d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
+  double alpha = 0.0;
+  double beta = 1.0;
+  double gamma = 0.0;
+  double delta = 0.0;
+  double ddelta = 0.0;
+  double theta = 0.0;
+  double kappa = 0.0;
+  double mu = 0.0;
+  double sigma = 0.0;
+  VALUE x_next = Qnil;
+  VALUE x_plus = Qnil;
+  VALUE j_plus = Qnil;
+  VALUE result = rb_hash_new();
+
+  for (n_iter = 0; n_iter < max_iter; n_iter++) {
+    if (success == 1) {
+      mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next));
+      if (mu >= 0.0) {
+        d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
+        mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next));
+      }
+      kappa = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, d));
+      if (kappa < 1e-16) break;
+
+      sigma = SIGMA_INIT / sqrt(kappa);
+      x_plus = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(sigma)));
+      ret = rb_funcall(fnc, rb_intern("call"), 2, x_plus, args);
+      j_plus = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
+      theta = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, rb_funcall(j_plus, rb_intern("-"), 1, j_next))) / sigma;
+    }
+
+    delta = theta + beta * kappa;
+    if (delta <= 0) {
+      delta = beta * kappa;
+      beta -= theta / kappa;
+    }
+    alpha = -mu / delta;
+
+    x_next = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha)));
+    ret = rb_funcall(fnc, rb_intern("call"), 2, x_next, args);
+    f_next = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc"))));
+
+    if (isinf(f_next) || isnan(f_next)) break;
+
+    ddelta = 2 * (f_next - f_prev) / (alpha * mu);
+    if (ddelta >= 0) {
+      success = 1;
+      n_successes++;
+      x = x_next;
+      f_curr = f_next;
+    } else {
+      success = 0;
+      f_curr = f_prev;
+    }
+
+    if (!NIL_P(logger)) rb_funcall(logger, rb_intern("call"), 2, INT2NUM(n_iter + 1), DBL2NUM(f_curr));
+
+    if (success == 1) {
+      if (fabs(f_next - f_prev) < ftol) break;
+      if (NUM2DBL(rb_funcall(rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha)), rb_intern("abs"), 0), rb_intern("max"), 0)) < xtol) break;
+
+      f_prev = f_next;
+      j_prev = j_next;
+      j_next = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
+
+      if (NUM2DBL(rb_funcall(j_next, rb_intern("dot"), 1, j_next)) < jtol) break;
+    }
+
+    if (ddelta < 0.25) beta = beta * 4 < BETA_MAX ? beta * 4 : BETA_MAX;
+    if (ddelta > 0.75) beta = beta / 4 > BETA_MIN ? beta / 4 : BETA_MIN;
+
+    if (n_successes == n_dimensions) {
+      d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
+      n_successes = 0;
+    } else if (success == 1) {
+      gamma = NUM2DBL(rb_funcall(rb_funcall(j_prev, rb_intern("-"), 1, j_next), rb_intern("dot"), 1, j_next)) / mu;
+      d = rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(gamma)), rb_intern("-"), 1, j_next);
+    }
+  }
+
+  rb_hash_aset(result, ID2SYM(rb_intern("x")), x);
+  rb_hash_aset(result, ID2SYM(rb_intern("n_iter")), INT2NUM(n_iter));
+
+  return result;
+}
+
+void init_utils_module()
+{
+  VALUE mUtils = rb_define_module_under(mRumale, "Utils");
+
+  rb_define_module_function(mUtils, "minimize", minimize, 8);
+}
diff --git a/ext/rumale/utils.h b/ext/rumale/utils.h
new file mode 100644
index 00000000..4c4c6cc9
--- /dev/null
+++ b/ext/rumale/utils.h
@@ -0,0 +1,11 @@
+#ifndef RUMALE_UTILS_H
+#define RUMALE_UTILS_H 1
+
+#include <math.h>
+#include <ruby.h>
+#include
+#include
+
+void init_utils_module();
+
+#endif /* RUMALE_UTILS_H */
diff --git a/spec/rumale/utils_spec.rb b/spec/rumale/utils_spec.rb
new file mode 100644
index 00000000..d7dd1397
--- /dev/null
+++ b/spec/rumale/utils_spec.rb
@@ -0,0 +1,26 @@
+# frozen_string_literal: true
+
+RSpec.describe Rumale::Utils do
+  describe '#minimize' do
+    let(:x) { Numo::DFloat.zeros(3) }
+    let(:f) { Numo::DFloat[[1, 1, 1], [1, 1, 0], [1, 0, 1], [1, 0, 0], [1, 0, 0]] }
+    let(:k) { Numo::DFloat[1.0, 0.3, 0.5] }
+
+    let(:fnc) do
+      proc do |x, args|
+        f, k = args
+        log_pdot = f.dot(x)
+        log_z = Math.log(Numo::NMath.exp(log_pdot).sum)
+        p = Numo::NMath.exp(log_pdot - log_z)
+        { fnc: log_z - k.dot(x),
+          jcb: f.transpose.dot(p) - k }
+      end
+    end
+
+    let(:res) { described_class.minimize(fnc, x, [f, k], 100, 1e-8, 1e-8, 1e-8, nil) }
+    let(:err) { (fnc.call(res[:x], [f, k])[:fnc] - fnc.call(Numo::DFloat[0, -0.52, 0.48], [f, k])[:fnc]).abs }
+
+    it { expect(res.keys).to match(%i[x n_iter]) }
+    it { expect(err).to be < 1.0e-5 }
+  end
+end
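
Usage sketch (assuming the extension has been compiled and rumale/numo-narray are loadable; the quadratic objective, target vector, and tolerances below are illustrative only): the new module function accepts any callable that takes the current point and the extra arguments and returns a Hash with :fnc (objective value) and :jcb (gradient vector), the same contract exercised by the spec above.

require 'numo/narray'
require 'rumale'

# Convex quadratic f(x) = ||x - t||^2 with gradient 2 * (x - t).
target = Numo::DFloat[1.0, -2.0, 3.0]
fnc = proc do |x, args|
  diff = x - args
  { fnc: (diff * diff).sum, jcb: 2 * diff }
end

x0 = Numo::DFloat.zeros(3)
logger = proc { |n_iter, fval| puts format('iter: %3d, f: %.8f', n_iter, fval) }

res = Rumale::Utils.minimize(fnc, x0, target, 100, 1e-8, 1e-8, 1e-8, logger)
# res[:x] should be close to target; res[:n_iter] reports the iterations used.

The logger argument may also be nil, as in the spec, in which case the per-iteration objective values are simply not reported.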