Skip to content

Commit

Permalink
Add a new gConfig.IONames field, fWeightFileDirPrefix. This allows one to…
Browse files Browse the repository at this point in the history
… add a prefix to the directory used to store the weights. By default they are stored in a directory starting with the dataset name.

With the prefix they will be stored in weightfile_prefix/dataset_name/weight_file_name.
This fixes ROOT-8887
  • Loading branch information
lmoneta committed Sep 27, 2019
1 parent 78f62f6 commit 48d5751
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 57 deletions.
17 changes: 9 additions & 8 deletions tmva/tmva/inc/TMVA/Config.h
@@ -1,4 +1,4 @@
// @(#)root/tmva $Id$
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss

/**********************************************************************************
Expand Down Expand Up @@ -79,7 +79,7 @@ namespace TMVA {
// ROOT::TSequentialExecutor &GetSeqExecutor() { return *fSeqfPool; }
#endif
/// Get executor class for multi-thread usage
/// In case when MT is not enabled will return a serial executor
/// In case when MT is not enabled will return a serial executor
Executor & GetThreadExecutor() { return fExecutor; }

/// Enable MT in TMVA (by default is on when ROOT::EnableImplicitMT() is set
Expand All @@ -90,7 +90,7 @@ namespace TMVA {

///Check if IMT is enabled
Bool_t IsMTEnabled() const { return fExecutor.GetPoolSize() > 1; }

public:

class VariablePlotting;
Expand Down Expand Up @@ -118,13 +118,14 @@ namespace TMVA {
// Group of user-customisable names used when composing weight-file paths.
// The defaults are assigned in the Config constructor.
class IONames {

public:

// Optional prefix for the weight-file directory (default: "", i.e. no prefix).
// When non-empty, Factory::BookMethod builds the directory as
// <prefix>/<dataset name>/<fWeightFileDir>.
TString fWeightFileDirPrefix;
// Name of the weight-file directory placed below the dataset directory (default: "weights").
TString fWeightFileDir;
// Default: "weights". Presumably the extension used for weight file names -- TODO confirm against MethodBase.
TString fWeightFileExtension;
// Default: "optionInfo". Presumably the directory for option reference files -- TODO confirm against its users.
TString fOptionsReferenceFileDir;
} fIONames; // Customisable weight file properties


private:

// private constructor
Expand All @@ -136,7 +137,7 @@ namespace TMVA {
static std::atomic<Config*> fgConfigPtr;
#else
static Config* fgConfigPtr;
#endif
#endif
private:

#if __cplusplus > 199711L
Expand All @@ -154,7 +155,7 @@ namespace TMVA {
#endif
mutable MsgLogger* fLogger; // message logger
MsgLogger& Log() const { return *fLogger; }

ClassDef(Config,0); // Singleton class for global configuration settings
};

Expand Down
2 changes: 1 addition & 1 deletion tmva/tmva/src/Config.cxx
Expand Up @@ -73,6 +73,7 @@ TMVA::Config::Config() :
fVariablePlotting.fUsePaperStyle = 0;

// IO names
fIONames.fWeightFileDirPrefix = "";
fIONames.fWeightFileDir = "weights";
fIONames.fWeightFileExtension = "weights";
fIONames.fOptionsReferenceFileDir = "optionInfo";
Expand Down Expand Up @@ -118,4 +119,3 @@ TMVA::Config& TMVA::Config::Instance()
return fgConfigPtr ? *fgConfigPtr :*(fgConfigPtr = new Config());
#endif
}

96 changes: 53 additions & 43 deletions tmva/tmva/src/Factory.cxx
Expand Up @@ -397,12 +397,17 @@ TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString t
conf->DeclareOptionRef( boostNum = 0, "Boost_num",
"Number of times the classifier will be boosted" );
conf->ParseOptions();
delete conf;
TString fFileDir;
delete conf; // this is name of weight file directory (weigh)
TString fileDir;
if(fModelPersistence)
{
fFileDir=loader->GetName();
fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
// find prefix in fWeightFileDir;
TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
fileDir = prefix;
if (!prefix.IsNull())
if (fileDir[fileDir.Length()-1] != '/') fileDir += "/";
fileDir += loader->GetName();
fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
}
// initialize methods
IMethod* im;
Expand All @@ -419,7 +424,7 @@ TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString t
Log() << kFATAL << "Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST

if (fModelPersistence)
methBoost->SetWeightFileDir(fFileDir);
methBoost->SetWeightFileDir(fileDir);
methBoost->SetModelPersistence(fModelPersistence);
methBoost->SetBoostedMethodName(theMethodName); // DSMTEST divided into two lines
methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
Expand All @@ -436,7 +441,7 @@ TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString t
if (!methCat) // DSMTEST
Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl; // DSMTEST

if(fModelPersistence) methCat->SetWeightFileDir(fFileDir);
if(fModelPersistence) methCat->SetWeightFileDir(fileDir);
methCat->SetModelPersistence(fModelPersistence);
methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager(); // DSMTEST
methCat->SetFile(fgTargetFile);
Expand All @@ -460,7 +465,7 @@ TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString t
return 0;
}

if(fModelPersistence) method->SetWeightFileDir(fFileDir);
if(fModelPersistence) method->SetWeightFileDir(fileDir);
method->SetModelPersistence(fModelPersistence);
method->SetAnalysisType( fAnalysisType );
method->SetupMethod();
Expand Down Expand Up @@ -494,18 +499,18 @@ TMVA::MethodBase* TMVA::Factory::BookMethod(TMVA::DataLoader *loader, Types::EMV

////////////////////////////////////////////////////////////////////////////////
/// Adds an already constructed method to be managed by this factory.
///
///
/// \note Private.
/// \note Know what you are doing when using this method. The method that you
/// are loading could be trained already.
///
/// are loading could be trained already.
///

TMVA::MethodBase* TMVA::Factory::BookMethodWeightfile(DataLoader *loader, TMVA::Types::EMVA methodType, const TString &weightfile)
{
TString datasetname = loader->GetName();
std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
DataSetInfo &dsi = loader->GetDataSetInfo();

IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile );
MethodBase *method = (dynamic_cast<MethodBase*>(im));

Expand All @@ -515,13 +520,19 @@ TMVA::MethodBase* TMVA::Factory::BookMethodWeightfile(DataLoader *loader, TMVA::
Log() << kERROR << "Cannot handle category methods for now." << Endl;
}

TString fFileDir;
TString fileDir;
if(fModelPersistence) {
fFileDir=loader->GetName();
fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
// Build the weight-file directory as [<prefix>/]<dataset name>/<fWeightFileDir>,
// where the optional prefix is gConfig().GetIONames().fWeightFileDirPrefix.
TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
fileDir = prefix;
// Add a path separator only when a non-empty prefix does not already end with one.
if (!prefix.IsNull() && fileDir[fileDir.Length() - 1] != '/')
   fileDir += "/";
// Append (do not assign) the dataset name: a plain assignment here would
// silently discard the prefix stored in fileDir just above.
fileDir += loader->GetName();
fileDir += "/" + gConfig().GetIONames().fWeightFileDir;
}

if(fModelPersistence) method->SetWeightFileDir(fFileDir);
if(fModelPersistence) method->SetWeightFileDir(fileDir);
method->SetModelPersistence(fModelPersistence);
method->SetAnalysisType( fAnalysisType );
method->SetupMethod();
Expand Down Expand Up @@ -887,14 +898,14 @@ Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodNam
}

////////////////////////////////////////////////////////////////////////////////
/// Argument iClass specifies the class to generate the ROC curve in a
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// Returns a ROC graph for a given method, or nullptr on error.
///
/// Note: Evaluation of the given method must have been run prior to ROC
/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMetods.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand All @@ -906,14 +917,14 @@ TGraph* TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bo
}

////////////////////////////////////////////////////////////////////////////////
/// Argument iClass specifies the class to generate the ROC curve in a
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// Returns a ROC graph for a given method, or nullptr on error.
///
/// Note: Evaluation of the given method must have been run prior to ROC
/// Note: Evaluation of the given method must have been run prior to ROC
/// generation through Factory::EvaluateAllMetods.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand All @@ -925,7 +936,7 @@ TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, B
Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
return nullptr;
}

if ( ! this->HasMethod(datasetname, theMethodName) ) {
Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
return nullptr;
Expand Down Expand Up @@ -960,10 +971,10 @@ TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, B
////////////////////////////////////////////////////////////////////////////////
/// Generate a collection of graphs, for all methods for a given class. Suitable
/// for comparing method performance.
///
/// Argument iClass specifies the class to generate the ROC curve in a
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand All @@ -977,10 +988,10 @@ TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t i
////////////////////////////////////////////////////////////////////////////////
/// Generate a collection of graphs, for all methods for a given class. Suitable
/// for comparing method performance.
///
/// Argument iClass specifies the class to generate the ROC curve in a
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand All @@ -999,14 +1010,14 @@ TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(TString datasetname, UInt_t

TString methodName = method->GetMethodName();
UInt_t nClasses = method->DataInfo().GetNClasses();

if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
continue;
}

TString className = method->DataInfo().GetClassInfo(iClass)->GetName();

TGraph *graph = this->GetROCCurve(datasetname, methodName, false, iClass);
graph->SetTitle(methodName);

Expand All @@ -1026,12 +1037,12 @@ TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(TString datasetname, UInt_t
}

////////////////////////////////////////////////////////////////////////////////
/// Draws ROC curves for all methods booked with the factory for a given class
/// Draws ROC curves for all methods booked with the factory for a given class
/// onto a canvas.
///
/// Argument iClass specifies the class to generate the ROC curve in a
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand All @@ -1044,10 +1055,10 @@ TCanvas * TMVA::Factory::GetROCCurve(TMVA::DataLoader *loader, UInt_t iClass)

////////////////////////////////////////////////////////////////////////////////
/// Draws ROC curves for all methods booked with the factory for a given class.
///
/// Argument iClass specifies the class to generate the ROC curve in a
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///
///
/// NOTE: The ROC curve is 1 vs. all where the given class is considered signal
/// and the others considered background. This is ok in binary classification
/// but in in multi class classification, the ROC surface is an N dimensional
Expand Down Expand Up @@ -1206,9 +1217,9 @@ void TMVA::Factory::TrainAllMethods()
//ToDo, Do we need to fill the DataSetManager of MethodBoost here too?


TString fFileDir= m->DataInfo().GetName();
fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
m->SetWeightFileDir(fFileDir);
TString fileDir= m->DataInfo().GetName();
fileDir+="/"+gConfig().GetIONames().fWeightFileDir;
m->SetWeightFileDir(fileDir);
m->SetModelPersistence(fModelPersistence);
m->SetSilentFile(IsSilentFile());
m->SetAnalysisType(fAnalysisType);
Expand Down Expand Up @@ -2527,4 +2538,3 @@ TH1F* TMVA::Factory::GetImportance(const int nbits,std::vector<Double_t> importa
// vih1->Draw("B");
return vih1;
}

10 changes: 5 additions & 5 deletions tmva/tmva/src/MethodBase.cxx
Expand Up @@ -764,8 +764,8 @@ void TMVA::MethodBase::AddRegressionOutput(Types::ETreeType type)
regRes->Resize( nEvents );

// Drawing the progress bar every event was causing a huge slowdown in the evaluation time
// So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100
// So we set some parameters to draw the progress bar a total of totalProgressDraws, i.e. only draw every 1 in 100

Int_t totalProgressDraws = 100; // total number of times to update the progress bar
Int_t drawProgressEvery = 1; // draw every nth event such that we have a total of totalProgressDraws
if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
Expand Down Expand Up @@ -1985,7 +1985,7 @@ TDirectory* TMVA::MethodBase::BaseDir() const
sdir = methodDir->mkdir(defaultDir);
sdir->cd();
// write weight file name into target file
if (fModelPersistence) {
if (fModelPersistence) {
TObjString wfilePath( gSystem->WorkingDirectory() );
TObjString wfileName( GetWeightFileName() );
wfilePath.Write( "TrainingPath" );
Expand Down Expand Up @@ -2042,7 +2042,7 @@ TDirectory *TMVA::MethodBase::MethodBaseDir() const
void TMVA::MethodBase::SetWeightFileDir( TString fileDir )
{
// Remember the weight-file directory and make sure it exists on disk.
fFileDir = fileDir;
// The path may be nested (e.g. <prefix>/<dataset name>/<weight dir>), so create it
// recursively: with the second argument set to kTRUE, missing parent directories
// are created too. A separate non-recursive MakeDirectory call is therefore
// redundant and could not create such nested paths anyway.
gSystem->mkdir( fFileDir, kTRUE );
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -3202,7 +3202,7 @@ void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
GetMethodType() != Types::kHMatrix) {
fout << " Transform( iV, -1 );" << std::endl;
}

if(GetAnalysisType() == Types::kMulticlass) {
fout << " retval = GetMulticlassValues__( iV );" << std::endl;
} else {
Expand Down

0 comments on commit 48d5751

Please sign in to comment.