diff --git a/clstm.cc b/clstm.cc index ed39590..4adc348 100644 --- a/clstm.cc +++ b/clstm.cc @@ -673,6 +673,34 @@ void average_weights(vector &networks) { distribute_weights(networks); } +int n_states(Network net) { + int total = 0; + walk_states(net, [&](const string &, Sequence *p) { + total += p->size() * p->rows() * p->cols(); + }); + return total; +} + +bool get_states(Network net, Float *data, int total, int gpu) { + int index = 0; + walk_states(net, [&](const string &, Sequence *p) { + for (int t = 0; t < p->size(); t++) + for (int i = 0; i < p->rows(); i++) + for (int b = 0; b < p->cols(); b++) data[index++] = (*p)[t].v(i, b); + }); + return total == index; +} + +bool set_states(Network net, const Float *data, int total, int gpu) { + int index = 0; + walk_states(net, [&](const string &, Sequence *p) { + for (int t = 0; t < p->size(); t++) + for (int i = 0; i < p->rows(); i++) + for (int b = 0; b < p->cols(); b++) (*p)[t].v(i, b) = data[index++]; + }); + return total == index; +} + int n_params(Network net) { int total = 0; walk_params(net, @@ -680,7 +708,7 @@ int n_params(Network net) { return total; } -void share_params(Network net, Float *params, int total, int gpu) { +bool share_params(Network net, Float *params, int total, int gpu) { int index = 0; walk_params(net, [&](const string &, Params *p) { int n = p->v.rows(); @@ -688,10 +716,10 @@ void share_params(Network net, Float *params, int total, int gpu) { p->v.displaceTo(params + index, n, m, gpu); index += p->v.total_size(); }); - assert(index == total); + return index == total; } -void set_params(Network net, Float *params, int total, int gpu) { +bool set_params(Network net, const Float *params, int total, int gpu) { assert(gpu < 0); int index = 0; walk_params(net, [&](const string &, Params *p) { @@ -701,10 +729,10 @@ void set_params(Network net, Float *params, int total, int gpu) { memcpy(p->v.ptr, params + index, nbytes); index += p->v.total_size(); }); - assert(index == total); + return index == 
total; } -void get_params(Network net, Float *params, int total, int gpu) { +bool get_params(Network net, Float *params, int total, int gpu) { assert(gpu < 0); int index = 0; walk_params(net, [&](const string &, Params *p) { @@ -714,14 +742,15 @@ void get_params(Network net, Float *params, int total, int gpu) { memcpy(params + index, p->v.ptr, nbytes); index += p->v.total_size(); }); - assert(index == total); + return index == total; } -void clear_derivs(Network net) { +bool clear_derivs(Network net) { walk_params(net, [&](const string &, Params *p) { p->d.setZero(); }); + return true; } -void get_derivs(Network net, Float *params, int total, int gpu) { +bool get_derivs(Network net, Float *params, int total, int gpu) { assert(gpu < 0); int index = 0; walk_params(net, [&](const string &, Params *p) { @@ -731,7 +760,7 @@ void get_derivs(Network net, Float *params, int total, int gpu) { memcpy(params + index, p->d.ptr, nbytes); index += p->v.total_size(); }); - assert(index == total); + return index == total; } } // namespace ocropus diff --git a/clstm.h b/clstm.h index a0c0f59..661ac71 100644 --- a/clstm.h +++ b/clstm.h @@ -156,11 +156,14 @@ void network_info(Network net, string prefix = ""); void network_detail(Network net, string prefix = ""); int n_params(Network net); -void clear_derivs(Network net); -void share_params(Network net, Float *params, int total, int gpu = -1); -void set_params(Network net, Float *params, int total, int gpu = -1); -void get_params(Network net, Float *params, int total, int gpu = -1); -void get_derivs(Network net, Float *params, int total, int gpu = -1); +bool clear_derivs(Network net); +bool share_params(Network net, Float *params, int total, int gpu = -1); +bool set_params(Network net, const Float *params, int total, int gpu = -1); +bool get_params(Network net, Float *params, int total, int gpu = -1); +bool get_derivs(Network net, Float *params, int total, int gpu = -1); +int n_states(Network net); +bool set_states(Network net, const 
Float *params, int total, int gpu = -1); +bool get_states(Network net, Float *params, int total, int gpu = -1); // setting inputs and outputs void set_classes(Network net, BatchClasses &classes); diff --git a/clstmfiltertrain.cc b/clstmfiltertrain.cc index 0870ba4..a54f97d 100644 --- a/clstmfiltertrain.cc +++ b/clstmfiltertrain.cc @@ -120,9 +120,9 @@ int main1(int argc, char **argv) { } test_error = errors / count; double exact_test_error = 1.0 - exact_matches / test_samples.size(); - print("ERROR", trial, test_error, " ", errors, count, - "exact_errors", exact_test_error, - "lrate", lrate, "momentum", momentum, "nhidden", nhidden); + print("ERROR", trial, test_error, " ", errors, count, "exact_errors", + exact_test_error, "lrate", lrate, "momentum", momentum, "nhidden", + nhidden); if (use_exact) test_error = exact_test_error; if (save_every == 0 && test_error < best_error) { best_error = test_error; diff --git a/clstmhl.h b/clstmhl.h index 35c180a..1df36b0 100644 --- a/clstmhl.h +++ b/clstmhl.h @@ -50,7 +50,7 @@ struct CLSTMText { nclasses = net->codec.size(); iclasses = net->icodec.size(); int neps = net->attr.get("neps", -1); - if (neps<0) cerr << "WARNING: no neps\n"; + if (neps < 0) cerr << "WARNING: no neps\n"; return true; } @@ -58,7 +58,7 @@ struct CLSTMText { // exception. void save(const std::string &fname) { if (!maybe_save(fname)) { - THROW("Could not save CLSTMText net to file: " + fname); + THROW("Could not save CLSTMText net to file: " + fname); } } @@ -178,7 +178,7 @@ struct CLSTMOCR { // exception. 
void save(const std::string &fname) { if (!maybe_save(fname)) { - THROW("Could not save CLSTMOCR net to file: " + fname); + THROW("Could not save CLSTMOCR net to file: " + fname); } } diff --git a/extras.cc b/extras.cc index 27082fb..b042ced 100644 --- a/extras.cc +++ b/extras.cc @@ -295,6 +295,7 @@ INormalizer *make_Normalizer(const string &name) { if (name == "mean") return make_MeanNormalizer(); if (name == "center") return make_CenterNormalizer(); THROW("unknown normalizer name"); + return 0; } // PNG I/O (taken from iulib) diff --git a/tensor.h b/tensor.h index acf68c1..af90cac 100644 --- a/tensor.h +++ b/tensor.h @@ -291,6 +291,7 @@ struct Tensor2 { return value; #else THROW("not compiled for GPU"); + return 0; #endif } } diff --git a/test-lstm.cc b/test-lstm.cc index 2f3c245..5f8a3e9 100644 --- a/test-lstm.cc +++ b/test-lstm.cc @@ -111,12 +111,13 @@ int main(int argc, char **argv) { print("OK", merr); } int nparams = n_params(net); + assert(nparams > 0); print("nparams", nparams); vector<Float> params(nparams); vector<Float> backup; - get_params(net, &params[0], nparams); + assert(get_params(net, &params[0], nparams)); backup = params; - share_params(net, &params[0], nparams); + assert(share_params(net, &params[0], nparams)); double merr2 = test_net(net); if (merr2 > 0.1) { print("FAILED (params)", merr2); diff --git a/utils.h b/utils.h index 792d21e..9eee24c 100644 --- a/utils.h +++ b/utils.h @@ -30,6 +30,12 @@ using std::endl; using std::cout; using std::cerr; +template <class A> +inline void die(const A &arg) { + cerr << "EXCEPTION (" << arg << ")\n"; + exit(255); +} + // get current time down to usec precision as a double inline double now() {