forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
predictor_utils.cc
82 lines (71 loc) · 2.38 KB
/
predictor_utils.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "caffe2/predictor/predictor_utils.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/predictor_consts.pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace predictor_utils {
/// Looks up the NetDef registered under `name` in a MetaNetDef.
///
/// @param def  MetaNetDef whose `nets` entries (key/value pairs) are searched.
/// @param name Key of the net to return.
/// @return Reference to the matching entry's value; it lives as long as `def`.
/// @throws (CAFFE_THROW) "Net not found: <name>" when no entry has that key.
CAFFE2_API const NetDef& getNet(
    const MetaNetDef& def,
    const std::string& name) {
  const int entryCount = def.nets_size();
  for (int idx = 0; idx < entryCount; ++idx) {
    const auto& entry = def.nets(idx);
    if (entry.key() != name) {
      continue;
    }
    // First entry with a matching key wins.
    return entry.value();
  }
  CAFFE_THROW("Net not found: ", name);
}
/// Scans `cursor` for the entry stored under `key` and decodes its value as a
/// MetaNetDef.
///
/// The stored value is expected to be a serialized BlobProto. Deserializing it
/// must yield a Blob holding a string, and that string is then parsed as a
/// MetaNetDef.
///
/// @param cursor Cursor over the model DB; must be non-null. If the cursor
///   supports seeking, it is first positioned via Seek(key). Otherwise the
///   scan starts from the cursor's current position. On success the cursor is
///   left on the matching entry (it is not advanced or rewound).
/// @param key DB key under which the MetaNetDef blob is stored.
/// @return The parsed MetaNetDef.
/// @throws (CAFFE_ENFORCE / CAFFE_THROW) if `cursor` is null, `key` is not
///   found, or the stored value cannot be decoded at any stage.
std::unique_ptr<MetaNetDef> extractMetaNetDef(
    db::Cursor* cursor,
    const std::string& key) {
  CAFFE_ENFORCE(cursor, "extractMetaNetDef requires a non-null cursor");
  if (cursor->SupportsSeek()) {
    // Position the cursor at `key` so the scan below starts there instead of
    // at the cursor's current position.
    cursor->Seek(key);
  }
  for (; cursor->Valid(); cursor->Next()) {
    if (cursor->key() != key) {
      continue;
    }
    // We've found a match. The value wraps the MetaNetDef in two layers:
    // BlobProto -> Blob<string> -> MetaNetDef.
    BlobProto proto;
    CAFFE_ENFORCE(
        proto.ParseFromString(cursor->value()),
        "Failed to parse BlobProto stored under key: ",
        key);
    Blob blob;
    DeserializeBlob(proto, &blob);
    CAFFE_ENFORCE(
        blob.IsType<string>(),
        "Blob stored under key ",
        key,
        " does not hold a serialized string");
    auto def = caffe2::make_unique<MetaNetDef>();
    CAFFE_ENFORCE(
        def->ParseFromString(blob.Get<string>()),
        "Failed to parse MetaNetDef stored under key: ",
        key);
    return def;
  }
  CAFFE_THROW("Failed to find in db the key: ", key);
}
/// Loads the MetaNetDef from `db` and gives `master` ownership of the DB
/// reader. It then runs the model's global init net once in `master`.
///
/// @param db Reader over the model DB; must be non-null. Ownership moves into
///   a blob of `master` named by PredictorConsts::predictor_dbreader(), so the
///   global init net can use it.
/// @param master Workspace that receives the DB reader blob and runs the
///   global init net; must be non-null.
/// @return The MetaNetDef extracted from the DB.
/// @throws (CAFFE_ENFORCE / CAFFE_THROW) on a null argument, a model that is
///   not a single predictor, a missing global init net, or a failed run.
std::unique_ptr<MetaNetDef> runGlobalInitialization(
    std::unique_ptr<db::DBReader> db,
    Workspace* master) {
  CAFFE_ENFORCE(db.get());
  // Check up front so a null workspace fails with a clear error instead of
  // crashing at CreateBlob after the DB has already been scanned.
  CAFFE_ENFORCE(
      master, "runGlobalInitialization requires a non-null workspace");
  const auto& consts = PredictorConsts::default_instance();
  auto* cursor = db->cursor();
  auto metaNetDef = extractMetaNetDef(cursor, consts.meta_net_def());
  if (metaNetDef->has_modelinfo()) {
    CAFFE_ENFORCE(
        metaNetDef->modelinfo().predictortype() == consts.single_predictor(),
        "Can only load single predictor");
  }
  VLOG(1) << "Extracted meta net def";
  // getNet returns a reference into *metaNetDef, which stays alive and
  // unmodified for the rest of this function. Binding by reference avoids
  // copying the whole NetDef.
  const auto& globalInitNet =
      getNet(*metaNetDef, consts.global_init_net_type());
  VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet);
  // Now, pass away ownership of the DB into the master workspace for
  // use by the globalInitNet.
  // NOTE(review): extractMetaNetDef leaves the reader's cursor on the
  // meta-net-def entry rather than rewinding it. Confirm that the ops in
  // globalInitNet do not depend on the cursor's starting position.
  master->CreateBlob(consts.predictor_dbreader())->Reset(db.release());
  // Now, with the DBReader set, we can run globalInitNet.
  CAFFE_ENFORCE(
      master->RunNetOnce(globalInitNet),
      "Failed running the globalInitNet: ",
      ProtoDebugString(globalInitNet));
  return metaNetDef;
}
} // namespace predictor_utils
} // namespace caffe2