From daa8f08e060d6d2bc415becc32eba36c0b4de524 Mon Sep 17 00:00:00 2001 From: tmbdev Date: Wed, 29 Jul 2015 15:18:31 -0700 Subject: [PATCH] diagnostics, saving, opt exn --- clstm.cc | 14 ++++++++++++-- clstmctc.cc | 2 ++ clstmfiltertrain.cc | 18 +++++++++++++++--- clstmtext.cc | 2 ++ 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/clstm.cc b/clstm.cc index 6b6b8d2..0ed805a 100644 --- a/clstm.cc +++ b/clstm.cc @@ -61,8 +61,18 @@ Network layer(const string &kind, const Networks &subs) { Network net; auto it = layer_factories.find(kind); - if (it != layer_factories.end()) - net.reset(it->second()); + if (it != layer_factories.end()) { + net.reset(it->second()); + } else { + string accepted_layer_kinds = ""; + for (auto val : layer_factories) { + accepted_layer_kinds += val.first; + accepted_layer_kinds += ","; + } + THROW("unknown layer type:" + kind + + ". Accepted layer kinds:" + accepted_layer_kinds); + } + for (auto it : args) { net->attributes[it.first] = it.second; } diff --git a/clstmctc.cc b/clstmctc.cc index 3a33fcd..a09fdb3 100644 --- a/clstmctc.cc +++ b/clstmctc.cc @@ -556,8 +556,10 @@ int main(int argc, char **argv) { } else { return main_eval(argc, argv); } +#ifndef NOEXCEPTION } catch(const char *msg) { print("EXCEPTION", msg); +#endif } catch(...) { print("UNKNOWN EXCEPTION"); } diff --git a/clstmfiltertrain.cc b/clstmfiltertrain.cc index 95c2e3e..f0049d5 100644 --- a/clstmfiltertrain.cc +++ b/clstmfiltertrain.cc @@ -71,7 +71,7 @@ int main1(int argc, char **argv) { vector samples, test_samples; read_samples(samples, argv[1]); if (argc > 2) read_samples(test_samples, argv[2]); - print("got", samples.size(), "files,", test_samples.size(), "tests"); + print("got", samples.size(), "inputs,", test_samples.size(), "tests"); vector icodec, codec; get_codec(icodec, samples, &Sample::in); @@ -88,6 +88,10 @@ int main1(int argc, char **argv) { int report_every = getienv("report_every", 100); int test_every = getienv("test_every", 10000); + // Command to execute after testing the networks performance. + string after_test = getsenv("after_test", ""); + + double best_error = 1e38; double test_error = 9999.0; for (int trial = 0; trial < maxtrain; trial++) { int sample = irandom() % samples.size(); @@ -102,14 +106,22 @@ int main1(int argc, char **argv) { } test_error = errors/count; print("ERROR", trial, test_error, " ", errors, count); + if (save_every == 0 && test_error < best_error) { + best_error = test_error; + print("saving best performing network so far", save_name, + "error rate: ", best_error); + string fname = save_name + ".h5"; + clstm.save(fname); + } + if (after_test != "") system(after_test.c_str()); } if (trial > 0 && save_every > 0 && trial%save_every == 0) { - string fname = save_name+"-"+to_string(trial)+".model.proto"; + string fname = save_name + "-" + to_string(trial) + ".h5"; clstm.save(fname); } wstring pred = clstm.train(samples[sample].in, samples[sample].out); if (trial%report_every == 0) { - print(trial); + print("trial", trial); print("INP", samples[sample].in); print("TRU", samples[sample].out); print("ALN", clstm.aligned_utf8()); diff --git a/clstmtext.cc b/clstmtext.cc index 53b19ee..546c947 100644 --- a/clstmtext.cc +++ b/clstmtext.cc @@ -425,8 +425,10 @@ int main(int argc, char **argv) { } else if (mode == "filter") { return main_filter(argc, argv); } +#ifndef NOEXCEPTION } catch(const char *msg) { print("EXCEPTION", msg); +#endif } catch(...) { print("UNKNOWN EXCEPTION"); }