Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix in MethodPyKeras for using the new Tensorflow version 2 #4977

Merged
merged 1 commit into from Feb 13, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 24 additions & 6 deletions tmva/pymva/src/MethodPyKeras.cxx
Expand Up @@ -171,15 +171,28 @@ void MethodPyKeras::ProcessOptions() {
Log() << kINFO << "Using TensorFlow backend - setting special configuration options " << Endl;
PyRunString("import tensorflow as tf");
PyRunString("from keras.backend import tensorflow_backend as K");

// check tensorflow version
PyRunString("tf_major_version = int(tf.__version__.split('.')[0])");
//PyRunString("print(tf.__version__,'major is ',tf_major_version)");
PyObject *pyTfVersion = PyDict_GetItemString(fLocalNS, "tf_major_version");
int tfVersion = PyLong_AsLong(pyTfVersion);
Log() << kINFO << "Using Tensorflow version " << tfVersion << Endl;

// use different naming in tf2 for ConfigProto and Session
TString configProto = (tfVersion >= 2) ? "tf.compat.v1.ConfigProto" : "tf.ConfigProto";
TString session = (tfVersion >= 2) ? "tf.compat.v1.Session" : "tf.Session";

// in case specify number of threads
int num_threads = fNumThreads;
if (num_threads > 0) {
Log() << kINFO << "Setting the CPU number of threads = " << num_threads << Endl;
PyRunString(TString::Format("session_conf = tf.ConfigProto(intra_op_parallelism_threads=%d,inter_op_parallelism_threads=%d)",
num_threads,num_threads));

PyRunString(TString::Format("session_conf = %s(intra_op_parallelism_threads=%d,inter_op_parallelism_threads=%d)",
configProto.Data(), num_threads,num_threads));
}
else
PyRunString("session_conf = tf.ConfigProto()");
PyRunString(TString::Format("session_conf = %s()",configProto.Data()));

// applying GPU options such as allow_growth=True to avoid allocating all memory on GPU
// that prevents running later TMVA-GPU
Expand All @@ -191,8 +204,13 @@ void MethodPyKeras::ProcessOptions() {
PyRunString(TString::Format("session_conf.gpu_options.%s", optlist->At(item)->GetName()));
}
}
PyRunString("sess = tf.Session(config=session_conf)");
PyRunString("K.set_session(sess)");
PyRunString(TString::Format("sess = %s(config=session_conf)", session.Data()));

if (tfVersion < 2) {
PyRunString("K.set_session(sess)");
} else {
PyRunString("tf.compat.v1.keras.backend.set_session(sess)");
}
}
else {
if (fNumThreads > 0)
Expand Down Expand Up @@ -458,7 +476,7 @@ void MethodPyKeras::Train() {
const char *stra_name = PyBytes_AsString(stra);
// need to add string delimiter for Python2
TString sname = TString::Format("'%s'",stra_name);
const char * name = sname.Data();
const char * name = sname.Data();
#else // for Python3
PyObject* repr = PyObject_Repr(stra);
PyObject* str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~");
Expand Down