Skip to content
Permalink
Browse files

[Project] Support callback function for train

  • Loading branch information
vstakhov committed Jun 30, 2019
1 parent b0dc150 commit 95edae6494dac4acf6ab19714a45339e515b8c49
Showing with 8 additions and 2 deletions.
  1. +8 −2 contrib/kann/kann.c
@@ -846,7 +846,10 @@ float kann_grad_clip(float thres, int n, float *g)
*** @@XY: simpler API for network with a single input/output ***
****************************************************************/

int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max_drop_streak, float frac_val, int n, float **_x, float **_y)
int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch,
int max_drop_streak, float frac_val, int n,
float **_x, float **_y,
kann_train_cb cb, void *ud)
{
int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0;
float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c;
@@ -907,14 +910,17 @@ int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch, int max
n_proc += ms;
}
if (n_val > 0) val_cost /= n_val;
if (kann_verbose >= 3) {
if (cb) {
cb(i + 1, train_cost, val_cost, ud);
#if 0
fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost);
if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train);
if (n_val > 0) {
fprintf(stderr, "; validation cost: %g", val_cost);
if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val);
}
fputc('\n', stderr);
#endif
}
if (i >= max_drop_streak && n_val > 0) {
if (val_cost < min_val_cost) {

0 comments on commit 95edae6

Please sign in to comment.
You can’t perform that action at this time.