Skip to content

Commit

Permalink
diagnostics, saving, opt exn
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbdev committed Jul 29, 2015
1 parent d9807f8 commit daa8f08
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
14 changes: 12 additions & 2 deletions clstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,18 @@ Network layer(const string &kind,
const Networks &subs) {
Network net;
auto it = layer_factories.find(kind);
if (it != layer_factories.end())
net.reset(it->second());
if (it != layer_factories.end()) {
net.reset(it->second());
} else {
string accepted_layer_kinds = "";
for (auto val : layer_factories) {
accepted_layer_kinds += val.first;
accepted_layer_kinds += ",";
}
THROW("unknown layer type:" + kind +
". Accepted layer kinds:" + accepted_layer_kinds);
}

for (auto it : args) {
net->attributes[it.first] = it.second;
}
Expand Down
2 changes: 2 additions & 0 deletions clstmctc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,10 @@ int main(int argc, char **argv) {
} else {
return main_eval(argc, argv);
}
#ifndef NOEXCEPTION
} catch(const char *msg) {
print("EXCEPTION", msg);
#endif
} catch(...) {
print("UNKNOWN EXCEPTION");
}
Expand Down
18 changes: 15 additions & 3 deletions clstmfiltertrain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ int main1(int argc, char **argv) {
vector<Sample> samples, test_samples;
read_samples(samples, argv[1]);
if (argc > 2) read_samples(test_samples, argv[2]);
print("got", samples.size(), "files,", test_samples.size(), "tests");
print("got", samples.size(), "inputs,", test_samples.size(), "tests");

vector<int> icodec, codec;
get_codec(icodec, samples, &Sample::in);
Expand All @@ -88,6 +88,10 @@ int main1(int argc, char **argv) {
int report_every = getienv("report_every", 100);
int test_every = getienv("test_every", 10000);

// Command to execute after testing the networks performance.
string after_test = getsenv("after_test", "");

double best_error = 1e38;
double test_error = 9999.0;
for (int trial = 0; trial < maxtrain; trial++) {
int sample = irandom() % samples.size();
Expand All @@ -102,14 +106,22 @@ int main1(int argc, char **argv) {
}
test_error = errors/count;
print("ERROR", trial, test_error, " ", errors, count);
if (save_every == 0 && test_error < best_error) {
best_error = test_error;
print("saving best performing network so far", save_name,
"error rate: ", best_error);
string fname = save_name + ".h5";
clstm.save(fname);
}
if (after_test != "") system(after_test.c_str());
}
if (trial > 0 && save_every > 0 && trial%save_every == 0) {
string fname = save_name+"-"+to_string(trial)+".model.proto";
string fname = save_name + "-" + to_string(trial) + ".h5";
clstm.save(fname);
}
wstring pred = clstm.train(samples[sample].in, samples[sample].out);
if (trial%report_every == 0) {
print(trial);
print("trial", trial);
print("INP", samples[sample].in);
print("TRU", samples[sample].out);
print("ALN", clstm.aligned_utf8());
Expand Down
2 changes: 2 additions & 0 deletions clstmtext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,10 @@ int main(int argc, char **argv) {
} else if (mode == "filter") {
return main_filter(argc, argv);
}
#ifndef NOEXCEPTION
} catch(const char *msg) {
print("EXCEPTION", msg);
#endif
} catch(...) {
print("UNKNOWN EXCEPTION");
}
Expand Down

0 comments on commit daa8f08

Please sign in to comment.