Skip to content
This repository has been archived by the owner on Jan 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #196 from ndawe/master
Browse files Browse the repository at this point in the history
Work around thread safety issue in TMVA::Reader
  • Loading branch information
ndawe committed May 5, 2015
2 parents 7b6d952 + df46bc8 commit a24e1c8
Show file tree
Hide file tree
Showing 8 changed files with 2,648 additions and 1,619 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ python:
- "2.7"
- "3.4"
env:
- ROOT=v5-34-18
- ROOT=v5-34-18 NOTMVA=1
- ROOT=master
- ROOT=master COVERAGE=1
install: source ci/install.sh
script: bash ci/test.sh
after_success:
Expand Down
5 changes: 3 additions & 2 deletions ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ time make install-user
time make test-installed

# Run tests in the local directory with coverage
if [ -z ${NOTMVA+x} ]; then
# TMVA is included in this build, so run the coverage
if [ ! -z ${COVERAGE+x} ] && [ -z ${NOTMVA+x} ]; then
# COVERAGE is set and TMVA is included in this build
# so run the coverage
time make test-coverage </dev/null
fi
2 changes: 1 addition & 1 deletion root_numpy/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
|_| \___/ \___/ \__|___|_| |_|\__,_|_| |_| |_| .__/ \__, | {0}
|_____| |_| |___/
"""
__version__ = '4.1.0.dev0'
__version__ = '4.1.1.dev0'
__doc__ = __doc__.format(__version__) # pylint:disable=redefined-builtin
44 changes: 15 additions & 29 deletions root_numpy/tmva/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,26 @@ def evaluate_reader(reader, name, events):
"""
if not isinstance(reader, TMVA.Reader):
raise TypeError("reader must be a TMVA.Reader instance")
method = reader.FindMVA(name)
if not method:
events = np.ascontiguousarray(events, dtype=np.float64)
if events.ndim == 1:
# convert to 2D
events = events[:, np.newaxis]
elif events.ndim != 2:
raise ValueError(
"method '{0}' is not booked in this reader".format(name))
return evaluate_method(method, events)
"events must be a two-dimensional array "
"with one event per row")
return _libtmvanumpy.evaluate_reader(ROOT.AsCObject(reader), name, events)


def evaluate_method(method, events):
"""Evaluate a TMVA::MethodBase over a NumPy array.
.. warning:: TMVA::Reader has known problems with thread safety in versions
of ROOT earlier than 6.03. There will potentially be a crash if you call
``method = reader.FindMVA(name)`` in Python and then pass this
``method`` here. Consider using ``evaluate_reader`` instead if you are
affected by this crash.
Parameters
----------
method : TMVA::MethodBase
Expand Down Expand Up @@ -75,28 +85,4 @@ def evaluate_method(method, events):
raise ValueError(
"events must be a two-dimensional array "
"with one event per row")
if events.shape[1] != method.GetNVariables():
raise ValueError(
"this method was trained with events containing "
"{0} variables, but these events contain {1} variables".format(
method.GetNVariables(), events.shape[1]))
analysistype = method.GetAnalysisType()
if analysistype == TMVA.Types.kClassification:
return _libtmvanumpy.evaluate_twoclass(
ROOT.AsCObject(method), events)
elif analysistype == TMVA.Types.kMulticlass:
n_classes = method.DataInfo().GetNClasses()
if n_classes < 2:
raise AssertionError("there must be at least two classes")
return _libtmvanumpy.evaluate_multiclass(
ROOT.AsCObject(method), events, n_classes)
elif analysistype == TMVA.Types.kRegression:
n_targets = method.DataInfo().GetNTargets()
if n_targets < 1:
raise AssertionError("there must be at least one regression target")
output = _libtmvanumpy.evaluate_regression(
ROOT.AsCObject(method), events, n_targets)
if n_targets == 1:
return np.ravel(output)
return output
raise AssertionError("the analysis type of this method is not supported")
return _libtmvanumpy.evaluate_method(ROOT.AsCObject(method), events)
30 changes: 27 additions & 3 deletions root_numpy/tmva/src/TMVA.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,42 @@ cdef extern from "TMVA/Types.h" namespace "TMVA":
kTraining "TMVA::Types::kTraining"
kTesting "TMVA::Types::kTesting"

ctypedef enum EAnalysisType "TMVA::Types::EAnalysisType":
kClassification "TMVA::Types::kClassification"
kRegression "TMVA::Types::kRegression"
kMulticlass "TMVA::Types::kMulticlass"

cdef extern from "TMVA/Event.h" namespace "TMVA":
cdef cppclass Event:
Event(vector[float]& features, unsigned int theclass)
void SetVal(unsigned int ivar, float value)

cdef extern from "TMVA/Factory.h" namespace "TMVA":
cdef cppclass Factory:
void AddEvent(string& classname, ETreeType treetype, vector[double]& event, double weight)
cdef extern from "TMVA/DataSetInfo.h" namespace "TMVA":
cdef cppclass DataSetInfo:
unsigned int GetNClasses()
unsigned int GetNVariables()
unsigned int GetNTargets()
vector[string] GetListOfVariables()

cdef extern from "TMVA/IMethod.h" namespace "TMVA":
cdef cppclass IMethod:
pass

cdef extern from "TMVA/MethodBase.h" namespace "TMVA":
cdef cppclass MethodBase:
EAnalysisType GetAnalysisType()
DataSetInfo DataInfo()
unsigned int GetNVariables()
unsigned int GetNTargets()
double GetMvaValue()
vector[float] GetMulticlassValues()
vector[float] GetRegressionValues()
Event* fTmpEvent

cdef extern from "TMVA/Factory.h" namespace "TMVA":
cdef cppclass Factory:
void AddEvent(string& classname, ETreeType treetype, vector[double]& event, double weight)

cdef extern from "TMVA/Reader.h" namespace "TMVA":
cdef cppclass Reader:
IMethod* FindMVA(string name)
Loading

0 comments on commit a24e1c8

Please sign in to comment.