-
Notifications
You must be signed in to change notification settings - Fork 2
/
testGraphLoading.cc
79 lines (62 loc) · 2.37 KB
/
testGraphLoading.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/*
* Tests for loading graphs via the converted protobuf files.
* Based on TensorFlow 2.1.
* For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
*
* Author: Marcel Rieger
*/
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "testBase.h"
/*
 * CppUnit fixture checking that a constant graph can be loaded from a protobuf file and evaluated.
 * The shared fixture machinery (e.g. dataPath_) comes from testBase, which is not visible here;
 * this class only supplies the script name and the single test case.
 */
class testGraphLoading : public testBase {
  // the suite consists of exactly one test case, checkAll
  CPPUNIT_TEST_SUITE(testGraphLoading);
  CPPUNIT_TEST(checkAll);
  CPPUNIT_TEST_SUITE_END();

public:
  // name of the Python script associated with this test
  std::string pyScript() const override;

  // loads the graph, runs it and checks the results and the error paths
  void checkAll() override;
};
// register the testGraphLoading suite with the default CppUnit test factory registry
CPPUNIT_TEST_SUITE_REGISTRATION(testGraphLoading);
/*
 * Returns the name of the Python script associated with this test.
 * NOTE(review): presumably testBase runs this script to produce the constantgraph.pb file that
 * checkAll loads from dataPath_ — confirm in testBase.
 */
std::string testGraphLoading::pyScript() const {
  const std::string scriptName = "createconstantgraph.py";
  return scriptName;
}
void testGraphLoading::checkAll() {
std::string pbFile = dataPath_ + "/constantgraph.pb";
// load the graph
tensorflow::setLogging();
tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
CPPUNIT_ASSERT(graphDef != nullptr);
// create a new session and add the graphDef
tensorflow::Session* session = tensorflow::createSession(graphDef);
CPPUNIT_ASSERT(session != nullptr);
// check for exception
CPPUNIT_ASSERT_THROW(tensorflow::createSession(nullptr), cms::Exception);
// example evaluation
tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
float* d = input.flat<float>().data();
for (size_t i = 0; i < 10; i++, d++) {
*d = float(i);
}
tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
scale.scalar<float>()() = 1.0;
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status status = session->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
if (!status.ok()) {
std::cout << status.ToString() << std::endl;
CPPUNIT_ASSERT(false);
}
// check the output
CPPUNIT_ASSERT(outputs.size() == 1);
std::cout << outputs[0].DebugString() << std::endl;
CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
// run again using the convenience helper
outputs.clear();
tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
CPPUNIT_ASSERT(outputs.size() == 1);
std::cout << outputs[0].DebugString() << std::endl;
CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
// check for exception
CPPUNIT_ASSERT_THROW(tensorflow::run(session, {{"foo", input}}, {"output"}, &outputs), cms::Exception);
// cleanup
CPPUNIT_ASSERT(tensorflow::closeSession(session));
delete graphDef;
}