Merge pull request BVLC#6123 from IlyaOvodov/master
"weights" added to solver parameters, "snapshot_prefix" field default initialization
Noiredd committed Feb 12, 2018
2 parents 87e1512 + 6fa4c62 commit a44c444
Showing 5 changed files with 69 additions and 18 deletions.
18 changes: 16 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -186,7 +186,11 @@ message SolverParameter {
  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval
  optional string snapshot_prefix = 15; // The prefix for the snapshot.
  // The prefix for the snapshot.
  // If not set, it defaults to the prototxt file path without its extension.
  // If set to a directory, the prototxt file name (without extension) is
  // appended to it.
  optional string snapshot_prefix = 15;
  // whether to snapshot diff in the results or not. Snapshotting diff will help
  // debugging but the final protocol buffer size will be much larger.
  optional bool snapshot_diff = 16 [default = false];
@@ -241,6 +245,16 @@ message SolverParameter {

  // Overlap compute and communication for data parallel training
  optional bool layer_wise_reduce = 41 [default = true];

  // Path to caffemodel file(s) with pretrained weights to initialize finetuning.
  // The same as the command-line --weights parameter of the caffe train command.
  // If the command-line --weights parameter is specified, it takes priority
  // and overrides these weights.
  // If the --snapshot command-line parameter is specified, these weights are
  // ignored.
  // Several model files can be listed either in a single weights parameter,
  // separated by ',' (as on the command line), or in separate repeated
  // weights parameters.
  repeated string weights = 42;
}

// A message that stores the solver snapshots
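
For illustration, a minimal sketch of a solver prototxt using the new field is shown below; the .caffemodel file names are hypothetical, and the commented-out repeated form is equivalent to the comma-separated entry. As the field comment above notes, a --weights flag on the caffe train command line overrides these entries, and --snapshot causes them to be ignored.

net: "examples/mnist/lenet_train_test.prototxt"
weights: "pretrained_a.caffemodel,pretrained_b.caffemodel"
# equivalent repeated form:
# weights: "pretrained_a.caffemodel"
# weights: "pretrained_b.caffemodel"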
21 changes: 21 additions & 0 deletions src/caffe/solver.cpp
@@ -3,6 +3,7 @@
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
@@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
  current_step_ = 0;
}

// Load weights from the caffemodel(s) specified in "weights" solver parameter
// into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net,
    const std::string& model_list) {
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(","));
  for (int i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);
    LOG(INFO) << "Finetuning from " << model_names[i];
    net->CopyTrainedLayersFrom(model_names[i]);
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
@@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
  net_state.MergeFrom(param_.train_state());
  net_param.mutable_state()->CopyFrom(net_state);
  net_.reset(new Net<Dtype>(net_param));
  for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
    LoadNetWeights(net_, param_.weights(w_idx));
  }
}

template <typename Dtype>
@@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
        << "Creating test net (#" << i << ") specified by " << sources[i];
    test_nets_[i].reset(new Net<Dtype>(net_params[i]));
    test_nets_[i]->set_debug_info(param_.debug_info());
    for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
      LoadNetWeights(test_nets_[i], param_.weights(w_idx));
    }
  }
}

4 changes: 4 additions & 0 deletions src/caffe/test/test_upgrade_proto.cpp
@@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
  for (int i = 0; i < 6; ++i) {
    const string& input_proto =
        "net: 'examples/mnist/lenet_train_test.prototxt' "
        "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
        "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
        "test_iter: 100 "
        "test_interval: 500 "
        "base_lr: 0.01 "
@@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
        "solver_type: " + std::string(old_type_vec[i]) + " ";
    const string& expected_output_proto =
        "net: 'examples/mnist/lenet_train_test.prototxt' "
        "weights: 'examples/mnist/lenet_train_test1.caffemodel' "
        "weights: 'examples/mnist/lenet_train_test2.caffemodel' "
        "test_iter: 100 "
        "test_interval: 500 "
        "base_lr: 0.01 "
21 changes: 21 additions & 0 deletions src/caffe/util/upgrade_proto.cpp
@@ -2,6 +2,8 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>

#include <boost/filesystem.hpp>

#include <map>
#include <string>

@@ -1095,12 +1097,31 @@ bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
  return success;
}

// Replaces snapshot_prefix of SolverParameter if it is not specified
// or is set to a directory.
void UpgradeSnapshotPrefixProperty(const string& param_file,
                                   SolverParameter* param) {
  using boost::filesystem::path;
  using boost::filesystem::is_directory;
  if (!param->has_snapshot_prefix()) {
    param->set_snapshot_prefix(path(param_file).replace_extension().string());
    LOG(INFO) << "snapshot_prefix was not specified and is set to "
                 + param->snapshot_prefix();
  } else if (is_directory(param->snapshot_prefix())) {
    param->set_snapshot_prefix((path(param->snapshot_prefix()) /
                                path(param_file).stem()).string());
    LOG(INFO) << "snapshot_prefix was a directory and is replaced with "
                 + param->snapshot_prefix();
  }
}

// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
                                       SolverParameter* param) {
  CHECK(ReadProtoFromTextFile(param_file, param))
      << "Failed to parse SolverParameter file: " << param_file;
  UpgradeSolverAsNeeded(param_file, param);
  UpgradeSnapshotPrefixProperty(param_file, param);
}

} // namespace caffe
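
A short sketch of the resolution rules above, with hypothetical paths: suppose the solver file is models/my_solver.prototxt.

# snapshot_prefix omitted:
#   resolved to "models/my_solver" (the solver file path without its extension)
# snapshot_prefix pointing at an existing directory:
snapshot_prefix: "snapshots"   # resolved to "snapshots/my_solver"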
23 changes: 7 additions & 16 deletions tools/caffe.cpp
@@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);

// Load the weights from the specified caffemodel(s) into the train and
// test nets.
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(",") );
  for (int i = 0; i < model_names.size(); ++i) {
    LOG(INFO) << "Finetuning from " << model_names[i];
    solver->net()->CopyTrainedLayersFrom(model_names[i]);
    for (int j = 0; j < solver->test_nets().size(); ++j) {
      solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
    }
  }
}

// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
@@ -233,6 +219,13 @@ int train() {
      GetRequestedAction(FLAGS_sigint_effect),
      GetRequestedAction(FLAGS_sighup_effect));

  if (FLAGS_snapshot.size()) {
    solver_param.clear_weights();
  } else if (FLAGS_weights.size()) {
    solver_param.clear_weights();
    solver_param.add_weights(FLAGS_weights);
  }

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

@@ -241,8 +234,6 @@ int train() {
  if (FLAGS_snapshot.size()) {
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  } else if (FLAGS_weights.size()) {
    CopyLayers(solver.get(), FLAGS_weights);
  }

  LOG(INFO) << "Starting Optimization";
