-
Notifications
You must be signed in to change notification settings - Fork 31
/
utils.c
137 lines (118 loc) · 4.85 KB
/
utils.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include "utils.h"
RUBY_EXTERN VALUE mRumale;
#define SIGMA_INIT 1e-4
#define BETA_MIN 1e-15
#define BETA_MAX 1e+15
/**
* @!visibility private
* The multivariate optimization with the scaled conjugate gradient method.
*
* *Reference*
* 1. Moller, M F., "A Scaled Conjugate Gradient Algorithm for Fast Supervised Learning," Neural Networks, Vol. 6, pp. 525--533, 1993.
*
* @overload minimize(fnc, x, args, max_iter, xtol, ftol, jtol, logger)
* @param fnc [Method/Proc] The method for calculating the objective function to be minimized and its gradient vector.
* @param x [Numo::DFloat] (shape: n_dimensions) The initial points.
* @param args [Array] The arguments pass to the 'fnc'.
* @param max_iter [Integer] The maximum number of iterations.
* @param xtol [Float] The tolerance for termination for the optimal vector norm.
* @param ftol [Float] The tolerance for termination for the objective function value.
* @param jtol [Float] The tolerance for termination for the gradient norm.
* @param logger [Method/Proc/Nil] The method for logging.
*
* @return [Hash] { x:, n_iter: } The minimization resuts.
*/
static VALUE
minimize(VALUE self, VALUE fnc, VALUE x, VALUE args, VALUE _max_iter, VALUE _xtol, VALUE _ftol, VALUE _jtol, VALUE logger)
{
const int max_iter = NUM2INT(_max_iter);
const double xtol = NUM2DBL(_xtol);
const double ftol = NUM2DBL(_ftol);
const double jtol = NUM2DBL(_jtol);
const int n_dimensions = NUM2INT(rb_funcall(x, rb_intern("size"), 0));
int n_iter = 0;
/* 'success' flags whether the last step reduced the objective; a failed step
 * reuses the cached direction quantities (mu, kappa, theta) on the next pass. */
char success = 1;
int n_successes = 0;
/* Initial evaluation: 'fnc' returns a Hash with :fnc (objective value, Float)
 * and :jcb (gradient, Numo::DFloat). */
VALUE ret = rb_funcall(fnc, rb_intern("call"), 2, x, args);
double f_prev = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc"))));
double f_next = 0.0;
double f_curr = f_prev;
VALUE j_prev = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
VALUE j_next = j_prev;
/* Initial search direction: steepest descent, d = -gradient. */
VALUE d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
/* Scalars follow Moller's notation: beta is the scale (lambda) regularizer,
 * mu = d.j, kappa = |d|^2, theta ~ d'Hd (finite-difference curvature),
 * delta the scaled curvature, ddelta the comparison parameter (Delta). */
double alpha = 0.0;
double beta = 1.0;
double gamma = 0.0;
double delta = 0.0;
double ddelta = 0.0;
double theta = 0.0;
double kappa = 0.0;
double mu = 0.0;
double sigma = 0.0;
VALUE x_next = Qnil;
VALUE x_plus = Qnil;
VALUE j_plus = Qnil;
VALUE result = rb_hash_new();
for (n_iter = 0; n_iter < max_iter; n_iter++) {
if (success == 1) {
/* Directional derivative mu = d.j; if non-negative, d is not a descent
 * direction, so restart with steepest descent. */
mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next));
if (mu >= 0.0) {
d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
mu = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, j_next));
}
/* kappa = |d|^2; a vanishing direction norm means convergence. */
kappa = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, d));
if (kappa < 1e-16) break;
/* Approximate curvature theta ~ d'Hd by a forward difference of the
 * gradient along d with step sigma (avoids forming the Hessian). */
sigma = SIGMA_INIT / sqrt(kappa);
x_plus = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(sigma)));
ret = rb_funcall(fnc, rb_intern("call"), 2, x_plus, args);
j_plus = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
theta = NUM2DBL(rb_funcall(d, rb_intern("dot"), 1, rb_funcall(j_plus, rb_intern("-"), 1, j_next))) / sigma;
}
/* Scaled curvature delta = theta + beta*kappa; if indefinite, force it
 * positive and grow the regularizer beta. */
delta = theta + beta * kappa;
if (delta <= 0) {
delta = beta * kappa;
beta -= theta / kappa;
}
/* Step size along d: alpha = -mu / delta (mu < 0, delta > 0 => alpha > 0). */
alpha = -mu / delta;
x_next = rb_funcall(x, rb_intern("+"), 1, rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha)));
ret = rb_funcall(fnc, rb_intern("call"), 2, x_next, args);
f_next = NUM2DBL(rb_hash_aref(ret, ID2SYM(rb_intern("fnc"))));
if (isinf(f_next) || isnan(f_next)) break;
/* Comparison parameter Delta = 2*(f_next - f_prev)/(alpha*mu): measures how
 * well the local quadratic model predicted the actual decrease. */
ddelta = 2 * (f_next - f_prev) / (alpha * mu);
if (ddelta >= 0) {
/* Successful step: accept x_next. */
success = 1;
n_successes++;
x = x_next;
f_curr = f_next;
} else {
/* Failed step: keep x; beta will be increased below and the step retried. */
success = 0;
f_curr = f_prev;
}
if (!NIL_P(logger)) rb_funcall(logger, rb_intern("call"), 2, INT2NUM(n_iter + 1), DBL2NUM(f_curr));
if (success == 1) {
/* Termination tests: objective change (ftol), max |alpha*d| step
 * component (xtol), and squared gradient norm (jtol). */
if (fabs(f_next - f_prev) < ftol) break;
if (NUM2DBL(rb_funcall(rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(alpha)), rb_intern("abs"), 0), rb_intern("max"), 0)) < xtol) break;
f_prev = f_next;
j_prev = j_next;
j_next = rb_hash_aref(ret, ID2SYM(rb_intern("jcb")));
if (NUM2DBL(rb_funcall(j_next, rb_intern("dot"), 1, j_next)) < jtol) break;
}
/* Adapt the scale: poor model fit (Delta < 0.25) => increase beta;
 * very good fit (Delta > 0.75) => decrease beta. Clamped to [BETA_MIN, BETA_MAX]. */
if (ddelta < 0.25) beta = beta * 4 < BETA_MAX ? beta * 4 : BETA_MAX;
if (ddelta > 0.75) beta = beta / 4 > BETA_MIN ? beta / 4 : BETA_MIN;
if (n_successes == n_dimensions) {
/* Restart with steepest descent every n_dimensions successful steps. */
d = rb_funcall(j_next, rb_intern("*"), 1, DBL2NUM(-1));
n_successes = 0;
} else if (success == 1) {
/* Conjugate update: d = gamma*d - j_next, with a Polak-Ribiere-style
 * gamma = (j_prev - j_next).j_next / mu (Moller's eq. for beta_k). */
gamma = NUM2DBL(rb_funcall(rb_funcall(j_prev, rb_intern("-"), 1, j_next), rb_intern("dot"), 1, j_next)) / mu;
d = rb_funcall(rb_funcall(d, rb_intern("*"), 1, DBL2NUM(gamma)), rb_intern("-"), 1, j_next);
}
}
/* Result Hash: :x is the best point found, :n_iter the iterations performed. */
rb_hash_aset(result, ID2SYM(rb_intern("x")), x);
rb_hash_aset(result, ID2SYM(rb_intern("n_iter")), INT2NUM(n_iter));
return result;
}
/**
 * @!visibility private
 * Defines the Rumale::Utils module under the Rumale root module and registers
 * the native `minimize` module function (8 arguments, see minimize above).
 *
 * Note: `(void)` makes this a proper prototype — an empty parameter list in a
 * C definition means "unspecified parameters" (removed in C23). The
 * RUBY_METHOD_FUNC cast is the Ruby C API's portable spelling for passing a
 * typed C function where `VALUE (*)(ANYARGS)` is expected.
 */
void init_utils_module(void)
{
VALUE mUtils = rb_define_module_under(mRumale, "Utils");
rb_define_module_function(mUtils, "minimize", RUBY_METHOD_FUNC(minimize), 8);
}