Skip to content

Commit

Permalink
APIs for internal state/params. Reformat.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbdev committed Jan 28, 2016
1 parent 14d1647 commit f37e299
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 22 deletions.
47 changes: 38 additions & 9 deletions clstm.cc
Expand Up @@ -673,25 +673,53 @@ void average_weights(vector<Network> &networks) {
distribute_weights(networks);
}

int n_states(Network net) {
int total = 0;
walk_states(net, [&](const string &, Sequence *p) {
total += p->size() * p->rows() * p->cols();
});
return total;
}

bool get_states(Network net, Float *data, int total, int gpu) {
int index = 0;
walk_states(net, [&](const string &, Sequence *p) {
for (int t = 0; t < p->size(); t++)
for (int i = 0; i < p->rows(); i++)
for (int b = 0; b < p->cols(); b++) data[index++] = (*p)[t].v(i, b);
});
return total == index;
}

bool set_states(Network net, const Float *data, int total, int gpu) {
int index = 0;
walk_states(net, [&](const string &, Sequence *p) {
for (int t = 0; t < p->size(); t++)
for (int i = 0; i < p->rows(); i++)
for (int b = 0; b < p->cols(); b++) (*p)[t].v(i, b) = data[index++];
});
return total == index;
}

int n_params(Network net) {
int total = 0;
walk_params(net,
[&](const string &, Params *p) { total += p->v.total_size(); });
return total;
}

void share_params(Network net, Float *params, int total, int gpu) {
bool share_params(Network net, Float *params, int total, int gpu) {
int index = 0;
walk_params(net, [&](const string &, Params *p) {
int n = p->v.rows();
int m = p->v.cols();
p->v.displaceTo(params + index, n, m, gpu);
index += p->v.total_size();
});
assert(index == total);
return index == total;
}

void set_params(Network net, Float *params, int total, int gpu) {
bool set_params(Network net, const Float *params, int total, int gpu) {
assert(gpu < 0);
int index = 0;
walk_params(net, [&](const string &, Params *p) {
Expand All @@ -701,10 +729,10 @@ void set_params(Network net, Float *params, int total, int gpu) {
memcpy(p->v.ptr, params + index, nbytes);
index += p->v.total_size();
});
assert(index == total);
return index == total;
}

void get_params(Network net, Float *params, int total, int gpu) {
bool get_params(Network net, Float *params, int total, int gpu) {
assert(gpu < 0);
int index = 0;
walk_params(net, [&](const string &, Params *p) {
Expand All @@ -714,14 +742,15 @@ void get_params(Network net, Float *params, int total, int gpu) {
memcpy(params + index, p->v.ptr, nbytes);
index += p->v.total_size();
});
assert(index == total);
return index == total;
}

void clear_derivs(Network net) {
bool clear_derivs(Network net) {
walk_params(net, [&](const string &, Params *p) { p->d.setZero(); });
return true;
}

void get_derivs(Network net, Float *params, int total, int gpu) {
bool get_derivs(Network net, Float *params, int total, int gpu) {
assert(gpu < 0);
int index = 0;
walk_params(net, [&](const string &, Params *p) {
Expand All @@ -731,7 +760,7 @@ void get_derivs(Network net, Float *params, int total, int gpu) {
memcpy(params + index, p->d.ptr, nbytes);
index += p->v.total_size();
});
assert(index == total);
return index == total;
}

} // namespace ocropus
13 changes: 8 additions & 5 deletions clstm.h
Expand Up @@ -156,11 +156,14 @@ void network_info(Network net, string prefix = "");
void network_detail(Network net, string prefix = "");

int n_params(Network net);
void clear_derivs(Network net);
void share_params(Network net, Float *params, int total, int gpu = -1);
void set_params(Network net, Float *params, int total, int gpu = -1);
void get_params(Network net, Float *params, int total, int gpu = -1);
void get_derivs(Network net, Float *params, int total, int gpu = -1);
bool clear_derivs(Network net);
bool share_params(Network net, Float *params, int total, int gpu = -1);
bool set_params(Network net, const Float *params, int total, int gpu = -1);
bool get_params(Network net, Float *params, int total, int gpu = -1);
bool get_derivs(Network net, Float *params, int total, int gpu = -1);
int n_states(Network net);
bool set_states(Network net, const Float *params, int total, int gpu = -1);
bool get_states(Network net, Float *params, int total, int gpu = -1);

// setting inputs and outputs
void set_classes(Network net, BatchClasses &classes);
Expand Down
6 changes: 3 additions & 3 deletions clstmfiltertrain.cc
Expand Up @@ -120,9 +120,9 @@ int main1(int argc, char **argv) {
}
test_error = errors / count;
double exact_test_error = 1.0 - exact_matches / test_samples.size();
print("ERROR", trial, test_error, " ", errors, count,
"exact_errors", exact_test_error,
"lrate", lrate, "momentum", momentum, "nhidden", nhidden);
print("ERROR", trial, test_error, " ", errors, count, "exact_errors",
exact_test_error, "lrate", lrate, "momentum", momentum, "nhidden",
nhidden);
if (use_exact) test_error = exact_test_error;
if (save_every == 0 && test_error < best_error) {
best_error = test_error;
Expand Down
6 changes: 3 additions & 3 deletions clstmhl.h
Expand Up @@ -50,15 +50,15 @@ struct CLSTMText {
nclasses = net->codec.size();
iclasses = net->icodec.size();
int neps = net->attr.get("neps", -1);
if (neps<0) cerr << "WARNING: no neps\n";
if (neps < 0) cerr << "WARNING: no neps\n";
return true;
}

// Saves the network to the given file. If this operation fails, throws an
// exception.
void save(const std::string &fname) {
if (!maybe_save(fname)) {
THROW("Could not save CLSTMText net to file: " + fname);
THROW("Could not save CLSTMText net to file: " + fname);
}
}

Expand Down Expand Up @@ -178,7 +178,7 @@ struct CLSTMOCR {
// exception.
void save(const std::string &fname) {
if (!maybe_save(fname)) {
THROW("Could not save CLSTMOCR net to file: " + fname);
THROW("Could not save CLSTMOCR net to file: " + fname);
}
}

Expand Down
1 change: 1 addition & 0 deletions extras.cc
Expand Up @@ -295,6 +295,7 @@ INormalizer *make_Normalizer(const string &name) {
if (name == "mean") return make_MeanNormalizer();
if (name == "center") return make_CenterNormalizer();
THROW("unknown normalizer name");
return 0;
}

// PNG I/O (taken from iulib)
Expand Down
1 change: 1 addition & 0 deletions tensor.h
Expand Up @@ -291,6 +291,7 @@ struct Tensor2 {
return value;
#else
THROW("not compiled for GPU");
return 0;
#endif
}
}
Expand Down
5 changes: 3 additions & 2 deletions test-lstm.cc
Expand Up @@ -111,12 +111,13 @@ int main(int argc, char **argv) {
print("OK", merr);
}
int nparams = n_params(net);
assert(nparams > 0);
print("nparams", nparams);
vector<float> params(nparams);
vector<float> backup;
get_params(net, &params[0], nparams);
assert(get_params(net, &params[0], nparams));
backup = params;
share_params(net, &params[0], nparams);
assert(share_params(net, &params[0], nparams));
double merr2 = test_net(net);
if (merr2 > 0.1) {
print("FAILED (params)", merr2);
Expand Down
6 changes: 6 additions & 0 deletions utils.h
Expand Up @@ -30,6 +30,12 @@ using std::endl;
using std::cout;
using std::cerr;

template <class A>
inline void die(const A &arg) {
cerr << "EXCEPTION (" << arg << ")\n";
exit(255);
}

// get current time down to usec precision as a double

inline double now() {
Expand Down

0 comments on commit f37e299

Please sign in to comment.