Skip to content

Commit

Permalink
Roc curve weights + update (#525)
Browse files Browse the repository at this point in the history
* Add support for event weights in ROC calculations

* Update authors for ROCCurve

* Make TMVA factory use event weights

* Update authors of TMVA Factory

* Fix error in implementation of ROCCalc::ComputeSensitivity.

Also fixes the axis labels and drawing direction to plot specificity on
the x-axis and sensitivity on the y-axis, which is one standard convention.

* Fix ROC curve oscillation bug

Due to the interpolation setting, the ROC curve could sometimes fail to be
monotonically decreasing. Changed to linear interpolation.

* Add ROC constructor for separate signal and background vectors

* Remove caching from ROCCurve::GetROCCurve

This is because we might want to recalculate it with a different
number of divisions.

* Fix ROC Curve crash when num_points < 2

* Format for clang-format

* Changes as discussed with L. Moneta

- Avoid unnecessary array copy
- Use reserve() to preallocate size of event arrays
- Use [] indexing instead of .at() since we assert sizes
- Use Double_t instead of Float_t for integral calc
- Flip axes of ROC plot

* Add documentation for ROCCurve public methods

* Fixes for clang-format

* More clang-format

* Last time with the doc clang-format
  • Loading branch information
ashlaban authored and lmoneta committed Apr 26, 2017
1 parent 26f4bbd commit 6d5a8ae
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 126 deletions.
46 changes: 28 additions & 18 deletions tmva/tmva/inc/TMVA/ROCCurve.h
@@ -1,5 +1,5 @@
// @(#)root/tmva $Id$
// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
// Author: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer, Kim Albertsson

/**********************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
Expand All @@ -13,6 +13,7 @@
* Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
* Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
* Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
* Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
* *
* Copyright (c) 2015: *
* CERN, Switzerland *
Expand Down Expand Up @@ -41,30 +42,39 @@ class TGraph;

namespace TMVA {

class MsgLogger;
class MsgLogger;

class ROCCurve {

class ROCCurve {
public:
ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets,
const std::vector<Float_t> &mvaWeights);

public:
ROCCurve( const std::vector<Float_t> & mvaS, const std::vector<Bool_t> & mvat);
ROCCurve(const std::vector<Float_t> &mvaValues, const std::vector<Bool_t> &mvaTargets);

~ROCCurve();
ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground,
const std::vector<Float_t> &mvaSignalWeights, const std::vector<Float_t> &mvaBackgroundWeights);

ROCCurve(const std::vector<Float_t> &mvaSignal, const std::vector<Float_t> &mvaBackground);

Double_t GetROCIntegral();
TGraph* GetROCCurve(const UInt_t points=100);//n divisions = #points -1
~ROCCurve();

private:
void EpsilonCount();
mutable MsgLogger* fLogger; //! message logger
MsgLogger& Log() const { return *fLogger; }
TGraph *fGraph;
std::vector<Float_t> fMvaS;
std::vector<Float_t> fMvaB;
std::vector<Float_t> fEpsilonSig;
std::vector<Float_t> fEpsilonBgk;
Double_t GetROCIntegral(const UInt_t points = 41);
TGraph *GetROCCurve(const UInt_t points = 100); // n divisions = #points -1

};
private:
mutable MsgLogger *fLogger; //! message logger
MsgLogger &Log() const { return *fLogger; }

TGraph *fGraph;

std::vector<Float_t> fMvaSignal;
std::vector<Float_t> fMvaBackground;
std::vector<Float_t> fMvaSignalWeights;
std::vector<Float_t> fMvaBackgroundWeights;

std::vector<Double_t> ComputeSensitivity(const UInt_t num_points);
std::vector<Double_t> ComputeSpecificity(const UInt_t num_points);
};
}
#endif
59 changes: 36 additions & 23 deletions tmva/tmva/src/Factory.cxx
@@ -1,6 +1,6 @@
// @(#)Root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
// Updated by: Omar Zapata
// Updated by: Omar Zapata, Kim Albertsson
/**********************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
* Package: TMVA *
Expand All @@ -21,6 +21,7 @@
* Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
* Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
* Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
* Kim Albertsson <kim.albertsson@cern.ch> - LTU & CERN *
* *
* Copyright (c) 2005-2015: *
* CERN, Switzerland *
Expand Down Expand Up @@ -761,37 +762,49 @@ TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, B
return nullptr;
}

std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResTypes;
TMVA::ROCCurve *rocCurve;
TGraph *graph;
TMVA::ROCCurve *rocCurve = nullptr;
TGraph *graph = nullptr;

if (this->fAnalysisType == Types::kClassification) {

std::vector<Float_t> * rawMvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
mvaRes = *rawMvaRes;

std::vector<Bool_t> * rawMvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
mvaResTypes = *rawMvaResType;
std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
std::vector<Float_t> mvaResWeights;

auto eventCollection = dataset->GetEventCollection();
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResType, mvaResWeights);

} else if (this->fAnalysisType == Types::kMulticlass) {
std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResTypes;
std::vector<Float_t> mvaResWeights;

std::vector<std::vector<Float_t>> * rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();

// Vector transpose due to values being stored as
// [ [0, 1, 2], [0, 1, 2], ... ]
// in ResultsMulticlass::GetValueVector.
for (auto & item : *rawMvaRes) {
mvaRes.push_back( item[iClass] );
mvaRes.reserve(rawMvaRes->size());
for (auto item : *rawMvaRes) {
mvaRes.push_back(item[iClass]);
}

auto eventCollection = dataset->GetEventCollection();
mvaResTypes.reserve(eventCollection.size());
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResTypes.push_back( ev->GetClass() == iClass );
mvaResTypes.push_back(ev->GetClass() == iClass);
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
}

rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes);

if ( ! rocCurve ) {
Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
return nullptr;
Expand All @@ -801,9 +814,9 @@ TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, B
delete rocCurve;

if(setTitles) {
graph->GetYaxis()->SetTitle("Background Rejection");
graph->GetXaxis()->SetTitle("Signal Efficiency");
graph->SetTitle( Form( "Background Rejection vs. Signal Efficiency (%s)", theMethodName.Data() ) );
graph->GetYaxis()->SetTitle("Background rejection (Specificity)");
graph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");
graph->SetTitle(Form("Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
}

return graph;
Expand Down Expand Up @@ -919,14 +932,14 @@ TCanvas * TMVA::Factory::GetROCCurve(TString datasetname, UInt_t iClass)
TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass);

if ( multigraph ) {
multigraph->Draw("AC");
multigraph->Draw("AL");

multigraph->GetYaxis()->SetTitle("Background Rejection");
multigraph->GetXaxis()->SetTitle("Signal Efficiency");
multigraph->GetYaxis()->SetTitle("Background rejection (Specificity)");
multigraph->GetXaxis()->SetTitle("Signal efficiency (Sensitivity)");

TString titleString = Form( "Background Rejection vs. Signal Efficiency");
TString titleString = Form("Signal efficiency vs. Background rejection");
if (this->fAnalysisType == Types::kMulticlass) {
titleString = Form( "Background Rejection vs. Signal Efficiency (Class=%i)", iClass );
titleString = Form("%s (Class=%i)", titleString.Data(), iClass);
}

// Workaround for TMultigraph not drawing title correctly.
Expand Down

0 comments on commit 6d5a8ae

Please sign in to comment.