Skip to content

Commit

Permalink
Adjust DeepFlavourTFJetTagsProducer to use TF graphs from protobuffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga authored and pablodecm committed Nov 15, 2017
1 parent 05706d7 commit 765a99c
Showing 1 changed file with 19 additions and 30 deletions.
49 changes: 19 additions & 30 deletions RecoBTag/DeepFlavour/plugins/DeepFlavourTFJetTagsProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@
#include "tensor_fillers.h"

// Declaration of the data structure that is hold by the edm::GlobalCache.
// In TensorFlow, the computational graph is stored in a stateless meta graph object which can be
// shared by multiple session instances which handle the initialization of variables related to the
// meta graph. Following this approach in CMSSW, a meta graph should be stored in a GlobalCache
// which can be accesses by sessions owned by multiple stream module copies. Instead of using only
// the plain meta graph, we make use of a cache struct that can be extended in the future if nedded.
// In addition, the meta graph is protected via std::atomic, which should not affect the performance
// as it is only accessed in the module constructor and not in the actual produce loop.
// In TensorFlow, the computational graph is stored in a stateless graph object which can be shared
// by multiple session instances which handle the initialization of variables related to the graph.
// Following this approach in CMSSW, a graph should be stored in a GlobalCache which can be accessed
// by sessions owned by multiple stream module copies. Instead of using only the plain graph, we
// make use of a cache struct that can be extended in the future if nedded. In addition, the graph
// is protected via std::atomic, which should not affect the performance as it is only accessed in
// the module constructor and not in the actual produce loop.
struct DeepFlavourTFCache {
DeepFlavourTFCache() : metaGraph(nullptr)
{
DeepFlavourTFCache() : graphDef(nullptr) {
}

std::atomic<tensorflow::MetaGraphDef*> metaGraph;
std::atomic<tensorflow::GraphDef*> graphDef;
};

class DeepFlavourTFJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<DeepFlavourTFCache>> {
Expand Down Expand Up @@ -74,7 +73,6 @@ class DeepFlavourTFJetTagsProducer : public edm::stream::EDProducer<edm::GlobalC

// session for TF evaluation
tensorflow::Session* session_;

// vector of learning phase tensors, i.e., boolean scalar tensors pointing to false
std::vector<tensorflow::Tensor> lp_tensors_;
// flag to evaluate model batch or jet by jet
Expand All @@ -97,9 +95,7 @@ DeepFlavourTFJetTagsProducer::DeepFlavourTFJetTagsProducer(const edm::ParameterS
tensorflow::setThreading(sessionOptions, nThreads, singleThreadPool);

// create the session using the meta graph from the cache
edm::FileInPath graphPath(iConfig.getParameter<edm::FileInPath>("graph_path"));
std::string exportDir = graphPath.fullPath().substr(0, graphPath.fullPath().find_last_of("/"));
session_ = tensorflow::createSession(cache->metaGraph, exportDir, sessionOptions);
session_ = tensorflow::createSession(cache->graphDef, sessionOptions);

// get output names from flav_table
const auto & flav_pset = iConfig.getParameter<edm::ParameterSet>("flav_table");
Expand Down Expand Up @@ -140,11 +136,11 @@ void DeepFlavourTFJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptio
desc.add<std::vector<std::string>>("input_names",
{ "input_1", "input_2", "input_3", "input_4", "input_5" });
desc.add<edm::FileInPath>("graph_path",
edm::FileInPath("RecoBTag/Combined/data/DeepFlavourV01_C_PtCut/saved_model.pb"));
edm::FileInPath("RecoBTag/Combined/data/DeepFlavourV01_C_PtCut/constant_graph.pb"));
desc.add<std::vector<std::string>>("lp_names",
{"globals_input_batchnorm/keras_learning_phase"});
{ "globals_input_batchnorm/keras_learning_phase" });
desc.add<std::vector<std::string>>("output_names",
{ "ID_pred/Softmax", "regression_pred/BiasAdd", });
{ "ID_pred/Softmax", "regression_pred/BiasAdd" });
{
edm::ParameterSetDescription psd0;
psd0.add<std::vector<unsigned int>>("probb", {0});
Expand All @@ -170,27 +166,20 @@ std::unique_ptr<DeepFlavourTFCache> DeepFlavourTFJetTagsProducer::initializeGlob
// set the tensorflow log level to error
tensorflow::setLogging("3");

// build the exportDir from graph_path
edm::FileInPath graphPath(iConfig.getParameter<edm::FileInPath>("graph_path"));
std::string exportDir = graphPath.fullPath().substr(0, graphPath.fullPath().find_last_of("/"));

// get threading config and build session options
size_t nThreads = iConfig.getParameter<unsigned int>("nThreads");
std::string singleThreadPool = iConfig.getParameter<std::string>("singleThreadPool");
tensorflow::SessionOptions sessionOptions;
tensorflow::setThreading(sessionOptions, nThreads, singleThreadPool);
// get the pb file
std::string pbFile = iConfig.getParameter<edm::FileInPath>("graph_path").fullPath();

// create the cache instance and attach the meta graph to it
// load the graph def and save it in the cache
DeepFlavourTFCache* cache = new DeepFlavourTFCache();
cache->metaGraph = tensorflow::loadMetaGraph(exportDir, "serve", sessionOptions);
cache->graphDef = tensorflow::loadGraphDef(pbFile);

return std::unique_ptr<DeepFlavourTFCache>(cache);
}

void DeepFlavourTFJetTagsProducer::globalEndJob(const DeepFlavourTFCache* cache)
{
if (cache->metaGraph != nullptr) {
delete cache->metaGraph;
if (cache->graphDef != nullptr) {
delete cache->graphDef;
}
}

Expand Down

0 comments on commit 765a99c

Please sign in to comment.