<img src="../tmva_logo.svg" height="50%" width="50%">

# TMVA Cross Validation Example

### Generate Data

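We define a function that generates a toy dataset: a TTree with two variables, *x* and *y*, each drawn from a Gaussian with mean *offset* and width *scale*, plus a sequential *eventID*.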

In [1]:
TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
{
   TRandom3 rng(seed);
   Float_t x = 0;
   Float_t y = 0;
   UInt_t eventID = 0;

   TTree *data = new TTree();
   data->Branch("x", &x, "x/F");
   data->Branch("y", &y, "y/F");
   // "/i" is the leaf type for a 32-bit unsigned integer, matching UInt_t eventID
   data->Branch("eventID", &eventID, "eventID/i");

   for (Int_t n = 0; n < nPoints; ++n) {
      x = rng.Gaus(offset, scale);
      y = rng.Gaus(offset, scale);

      // For our simple example it is enough that the IDs are uniformly
      // distributed and independent of the data.
      ++eventID;

      data->Fill();
   }

   // Important: disconnect the branches from the local variables x, y and
   // eventID, which go out of scope when this function returns.
   data->ResetBranchAddresses();
   return data;
}
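
For readers without ROOT at hand, the same toy model can be sketched with the C++ standard library alone. This is an illustrative analogue, not part of the TMVA workflow: *ToyEvent*, *genToyEvents* and *meanX* are hypothetical names, and *std::mt19937* stands in for *TRandom3*, so the generated numbers differ from those produced by *genTree*.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// One toy event (hypothetical type): mirrors the branches x, y and eventID of genTree.
struct ToyEvent {
   float x;
   float y;
   std::uint32_t eventID;
};

// Standard-library analogue of genTree (hypothetical helper): x and y are drawn
// independently from Gauss(offset, scale). std::mt19937 replaces TRandom3.
std::vector<ToyEvent> genToyEvents(int nPoints, double offset, double scale,
                                   unsigned seed = 100)
{
   assert(nPoints >= 0 && scale > 0.0);
   std::mt19937 rng(seed);
   std::normal_distribution<double> gaus(offset, scale);
   std::vector<ToyEvent> events;
   events.reserve(static_cast<std::size_t>(nPoints));
   for (int n = 0; n < nPoints; ++n) {
      ToyEvent ev{};
      ev.x = static_cast<float>(gaus(rng));
      ev.y = static_cast<float>(gaus(rng));
      ev.eventID = static_cast<std::uint32_t>(n + 1); // starts at 1, like ++eventID before Fill()
      events.push_back(ev);
   }
   return events;
}

// Sample mean of x, useful to check that the offset was applied.
double meanX(const std::vector<ToyEvent> &events)
{
   if (events.empty()) return 0.0;
   double sum = 0.0;
   for (const auto &ev : events) sum += ev.x;
   return sum / static_cast<double>(events.size());
}
```

For 1000 events with unit width, the standard error of the sample mean is about 1/sqrt(1000), roughly 0.03, so *meanX* of the signal sample (offset 1.0) should land close to 1.0 and that of the background sample (offset -1.0) close to -1.0.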

### Define Output File

We load TMVA and open the ROOT file that will hold the output.

In [2]:
TMVA::Tools::Instance();

auto outputFile = TFile::Open("CV_Output.root", "RECREATE");

## TMVA Factory

Start by creating an instance of the Factory class. The factory is used to choose the methods whose performance you want to investigate.

The factory is the main TMVA object you interact with. Its constructor takes the following parameters:

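 - The first argument is the job name, which is used as the base name of all output weight files. These files store the trained method parameters and are written to the *weights/* subdirectory of the dataset directory (*dataset/weights/* in this example)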

 - The second argument is the output file for the training results
  
 - The third argument is an option string with general settings for the TMVA session. For example, all TMVA output can be suppressed by removing the "!" (not) in front of "Silent" in the option string
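
The option strings used throughout this notebook follow one convention: tokens are separated by ':', *Key=Value* sets a value, a bare *Flag* switches a boolean option on and *!Flag* switches it off. The sketch below, built around the hypothetical helper *parseOptions*, only illustrates that convention; it is not TMVA's own option parser, which additionally validates option names and allowed values.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <string>

// Illustrative parser (hypothetical helper) for a TMVA-style option string such as
// "!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification".
std::map<std::string, std::string> parseOptions(const std::string &opts)
{
   std::map<std::string, std::string> parsed;
   std::istringstream stream(opts);
   std::string token;
   while (std::getline(stream, token, ':')) {
      if (token.empty()) continue;                // tolerate "::" or a trailing ':'
      const auto eq = token.find('=');
      if (eq != std::string::npos) {
         assert(eq != 0 && "option token has an empty name");
         parsed[token.substr(0, eq)] = token.substr(eq + 1); // "Key=Value"
      } else if (token[0] == '!') {
         parsed[token.substr(1)] = "false";       // "!Silent" -> Silent = false
      } else {
         parsed[token] = "true";                  // "ROC" -> ROC = true
      }
   }
   return parsed;
}
```

Applied to the factory options used below, this yields six entries, for example *V=false*, *ROC=true* and *AnalysisType=Classification*.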


In [3]:
TMVA::Factory factory("TMVAClassification", outputFile,
                      "!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification" ); 

## DataLoader

The next step is to create the DataLoader, which provides the interface between TMVA and the input data.

### Define input variables

Through the DataLoader we define the input variables that will be used for the MVA training.

In [4]:
TMVA::DataLoader * loader = new TMVA::DataLoader("dataset");

loader->AddVariable("x", 'F');
loader->AddVariable("y", 'F');

## Setup Dataset(s)

Generate the signal and background trees and register them with the DataLoader. No input file is needed, since the data are generated on the fly.

In [5]:
// Generate signal and background data
TTree *tsignal = genTree(1000, 1.0, 1.0, 100);
TTree *tbackground = genTree(1000, -1.0, 1.0, 101);

// Register this data in the dataloader
loader->AddSignalTree(tsignal, 1.0);
loader->AddBackgroundTree(tbackground, 1.0);   

// Tell the DataLoader how to split the events into training and test samples.
//
// If no numbers of events are given, half of the events in the tree are used
// for training, and the other half for testing:
//    loader->PrepareTrainingAndTestTree( mycut, "SplitMode=Random:!V" );
// To also specify the number of testing events, use:
//    loader->PrepareTrainingAndTestTree( mycut,
//                                         "NSigTrain=3000:NBkgTrain=3000:NSigTest=3000:NBkgTest=3000:SplitMode=Random:!V" );

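// Put all 1000 signal and 1000 background events into the training sample;
// the cross validation below then splits these 2000 events into folds itself.
loader->PrepareTrainingAndTestTree("",
        "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");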

DataSetInfo              : [dataset] : Added class "Signal"
                         : Add Tree  of type Signal with 1000 events
DataSetInfo              : [dataset] : Added class "Background"
                         : Add Tree  of type Background with 1000 events
                         : Dataset[dataset] : Class index : 0  name : Signal
                         : Dataset[dataset] : Class index : 1  name : Background


## Run Cross Validation

### Format

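 - The first argument is the job name, which is used as the prefix of the weight files (e.g. *TMVACrossValidation_Fisher_fold1.weights.xml*)

 - The second argument is the DataLoader object

 - The third argument is the output file object

 - The fourth argument is an option string configuring the cross validation, for example the analysis type (classification, regression, etc.), the split type and the number of folds.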
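
With *NumFolds=5*, the 2000 events are divided into five folds; each fold is held out once for evaluation while the method is trained on the other four, consistent with the 1600 training events per fold reported in the training log further down. The sketch below illustrates the idea behind a random stratified split (*SplitType=RandomStratified*): each class is shuffled separately and dealt round-robin into the folds, so every fold keeps the overall signal/background ratio. *assignStratifiedFolds* is a hypothetical helper, and TMVA's actual fold assignment may differ in detail.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Sketch of a random stratified k-fold split (hypothetical helper).
// labels[i] is the class index of event i (0 = Signal, 1 = Background);
// the return value holds the fold index (0 .. numFolds-1) of each event.
std::vector<int> assignStratifiedFolds(const std::vector<int> &labels,
                                       int numFolds, unsigned seed)
{
   assert(numFolds > 1);
   std::mt19937 rng(seed);
   std::vector<int> fold(labels.size(), -1);
   const int numClasses =
      labels.empty() ? 0 : *std::max_element(labels.begin(), labels.end()) + 1;
   for (int cls = 0; cls < numClasses; ++cls) {
      // Collect and shuffle the events of this class only.
      std::vector<std::size_t> idx;
      for (std::size_t i = 0; i < labels.size(); ++i)
         if (labels[i] == cls) idx.push_back(i);
      std::shuffle(idx.begin(), idx.end(), rng);
      // Deal them round-robin so each fold gets (almost) the same number.
      for (std::size_t k = 0; k < idx.size(); ++k)
         fold[idx[k]] = static_cast<int>(k % static_cast<std::size_t>(numFolds));
   }
   return fold;
}
```

For 1000 signal and 1000 background events and five folds, every fold receives 200 events of each class, leaving 1600 events for training.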


In [6]:
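// Adjacent string literals are concatenated into one option string (no ';' between them).
// SplitExpr is omitted here: it is only relevant for SplitType=Deterministic.
TString cvOptions = "!V:!Silent:ModelPersistence:AnalysisType=Classification"
                    ":SplitType=RandomStratified:NumFolds=5";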

auto cv = new TMVA::CrossValidation("TMVACrossValidation",loader,outputFile,cvOptions);

## Booking Methods


Here we book the MVA methods we want to evaluate.
Each method is specified with the appropriate enumeration defined in *TMVA::Types*; see the file *TMVA/Types.h* for all available MVA methods.
In addition, an option string specifies all the method parameters. For all possible options and their default values, see the corresponding documentation in the TMVA Users Guide.

Note that when booking a method one can also specify individual variable transformations to be applied before the method is used.
For example, *VarTransform=Decorrelate* decorrelates the inputs.
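
As described in the TMVA Users Guide, *VarTransform=Decorrelate* computes the square root C' of the covariance matrix C (so that C = C'C') and transforms the inputs as x' = C'^(-1) x, which gives the transformed variables unit covariance. For two variables the matrix square root has a closed form, sketched below with the hypothetical helpers *inverseSqrt* and *decorrelate*. TMVA handles any number of variables and chooses itself which events are used to estimate C, so this is only an illustration of the linear algebra.

```cpp
#include <array>
#include <cassert>
#include <cmath>

using Mat2 = std::array<std::array<double, 2>, 2>;

// Inverse square root of a symmetric positive-definite 2x2 matrix (hypothetical helper).
// Closed form: sqrt(C) = (C + s*I) / t with s = sqrt(det C), t = sqrt(trace C + 2s).
Mat2 inverseSqrt(const Mat2 &c)
{
   const double det = c[0][0] * c[1][1] - c[0][1] * c[1][0];
   assert(det > 0.0 && c[0][0] > 0.0); // positive definite
   const double s = std::sqrt(det);
   const double t = std::sqrt(c[0][0] + c[1][1] + 2.0 * s);
   // r = sqrt(C)
   const Mat2 r{{{(c[0][0] + s) / t, c[0][1] / t},
                 {c[1][0] / t, (c[1][1] + s) / t}}};
   // Invert the 2x2 matrix r explicitly.
   const double detR = r[0][0] * r[1][1] - r[0][1] * r[1][0];
   return Mat2{{{ r[1][1] / detR, -r[0][1] / detR},
                {-r[1][0] / detR,  r[0][0] / detR}}};
}

// Apply x' = C^(-1/2) x to one 2D input vector (hypothetical helper).
std::array<double, 2> decorrelate(const Mat2 &cInvSqrt, const std::array<double, 2> &x)
{
   return {cInvSqrt[0][0] * x[0] + cInvSqrt[0][1] * x[1],
           cInvSqrt[1][0] * x[0] + cInvSqrt[1][1] * x[1]};
}
```

Because C^(-1/2) is symmetric, the covariance of the transformed variables is C^(-1/2) C C^(-1/2), which is the identity matrix.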

In [7]:
//cv->BookMethod(TMVA::Types::kBDT, "BDT",
//        "NTrees=10:MinNodeSize=2.5%:MaxDepth=2:nCuts=20");

cv->BookMethod(TMVA::Types::kFisher, "Fisher",
                 "!H:!V:Fisher:VarTransform=None");
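
The Fisher method booked above is a linear discriminant: its response is a weighted sum of the input variables plus an offset, as the coefficient table in the training log below shows. The textbook recipe chooses the weights w = W^(-1) (mu_S - mu_B), where mu_S and mu_B are the signal and background means and W is the within-class covariance matrix. The two-variable sketch below implements that recipe with the hypothetical helpers *mean2* and *fisherWeights*; TMVA applies its own normalisation and offset, so the sketch is not expected to reproduce the printed coefficients.

```cpp
#include <array>
#include <cassert>
#include <vector>

using Vec2 = std::array<double, 2>;

// Mean of a set of 2D points (hypothetical helper).
Vec2 mean2(const std::vector<Vec2> &pts)
{
   assert(!pts.empty());
   Vec2 m{0.0, 0.0};
   for (const auto &p : pts) { m[0] += p[0]; m[1] += p[1]; }
   m[0] /= pts.size();
   m[1] /= pts.size();
   return m;
}

// Textbook Fisher weights w = W^(-1) (muS - muB) (hypothetical helper),
// where W is the sum of the signal and background covariance matrices.
Vec2 fisherWeights(const std::vector<Vec2> &sig, const std::vector<Vec2> &bkg)
{
   const Vec2 muS = mean2(sig), muB = mean2(bkg);
   double w00 = 0.0, w01 = 0.0, w11 = 0.0; // symmetric 2x2 within-class matrix
   auto accumulate = [&](const std::vector<Vec2> &pts, const Vec2 &mu) {
      for (const auto &p : pts) {
         const double dx = p[0] - mu[0], dy = p[1] - mu[1];
         w00 += dx * dx / pts.size();
         w01 += dx * dy / pts.size();
         w11 += dy * dy / pts.size();
      }
   };
   accumulate(sig, muS);
   accumulate(bkg, muB);
   const double det = w00 * w11 - w01 * w01;
   assert(det != 0.0);
   const double d0 = muS[0] - muB[0], d1 = muS[1] - muB[1];
   // Explicit 2x2 inverse of W applied to the mean difference.
   return Vec2{( w11 * d0 - w01 * d1) / det,
               (-w01 * d0 + w00 * d1) / det};
}
```

When x and y separate the classes equally well, as in the symmetric toy data here, the two weights come out equal.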

### Perform the Cross Validation: Train/Test the booked methods

In [8]:
// Run cross-validation
cv->Evaluate();

                         : Evaluate method: Fisher
<HEADER> Factory                  : Booking method: Fisher_fold1
                         : 
<HEADER> Fisher_fold1             : Results for Fisher coefficients:
                         : -----------------------
                         : Variable:  Coefficient:
                         : -----------------------
                         :        x:       +0.478
                         :        y:       +0.437
                         : (offset):       +0.007
                         : -----------------------
                         : Elapsed time for training with 1600 events: 0.00155 sec         
<HEADER> Fisher_fold1             : [dataset] : Evaluation of Fisher_fold1 on training sample (1600 events)
                         : Elapsed time for evaluation of 1600 events: 0.000658 sec       
                         : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
                         : C

<HEADER> Factory                  : Thank you for using TMVA!
                         : For citation information, please visit: http://tmva.sf.net/citeTMVA.html
<HEADER> Factory                  : Booking method: Fisher
                         : 
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold3.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold3.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold4.weig

## Cross Validation Result

In [9]:
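// GetResults() returns the results as const; cast the constness away to get a mutable reference
TMVA::CrossValidationResult &result = const_cast<TMVA::CrossValidationResult &>(cv->GetResults()[0]);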

result.Print();


<HEADER> CrossValidation          :  ==== Results ====
                         : Fold  0 ROC-Int : 0.9749
                         : Fold  1 ROC-Int : 0.9711
                         : Fold  2 ROC-Int : 0.9766
                         : Fold  3 ROC-Int : 0.9632
                         : Fold  4 ROC-Int : 0.9706
                         : ------------------------
                         : Average ROC-Int : 0.9713
                         : Std-Dev ROC-Int : 0.0052
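
The average and standard deviation are plain summary statistics over the five fold values. The sketch below (hypothetical helper *summarizeFolds*) recomputes them from the printed, rounded fold values. The sample standard deviation with divisor n-1 gives 0.0052 as printed, whereas the divisor n would give about 0.0046, so the printed value appears to use n-1.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

struct FoldSummary {
   double mean;
   double stdDev; // sample standard deviation (divisor n-1)
};

// Mean and sample standard deviation of per-fold scores (hypothetical helper).
FoldSummary summarizeFolds(const std::vector<double> &scores)
{
   assert(scores.size() >= 2);
   double sum = 0.0;
   for (double s : scores) sum += s;
   const double mean = sum / static_cast<double>(scores.size());
   double sq = 0.0;
   for (double s : scores) sq += (s - mean) * (s - mean);
   return {mean, std::sqrt(sq / static_cast<double>(scores.size() - 1))};
}
```

Called with the five fold ROC integrals {0.9749, 0.9711, 0.9766, 0.9632, 0.9706}, it returns a mean of about 0.9713 and a standard deviation of about 0.0052.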


## Plot ROC Curves
We enable JSROOT, ROOT's JavaScript graphics, so that the plots are interactive in the notebook.

In [10]:
%jsroot on

In [11]:
result.Draw();
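
The plot shows one ROC curve per fold. In TMVA's usual convention, each point gives the signal efficiency and the background rejection (1 minus the background efficiency) for one cut on the classifier response, and the ROC-Int values printed earlier are the areas under these curves. The sketch below, with the hypothetical helpers *rocPoint* and *rocIntegral*, computes both from lists of classifier responses. The area is computed as the probability that a signal event scores higher than a background event (ties count one half), which equals the area under the background-rejection versus signal-efficiency curve. TMVA computes its integral from its own binned curve, so small numerical differences are to be expected.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// One ROC point for the cut "response > cut" (hypothetical helper):
// returns {signal efficiency, background rejection}.
std::pair<double, double> rocPoint(const std::vector<double> &sigScores,
                                   const std::vector<double> &bkgScores, double cut)
{
   assert(!sigScores.empty() && !bkgScores.empty());
   double nSigPass = 0.0, nBkgPass = 0.0;
   for (double s : sigScores) if (s > cut) nSigPass += 1.0;
   for (double b : bkgScores) if (b > cut) nBkgPass += 1.0;
   const double sigEff = nSigPass / static_cast<double>(sigScores.size());
   const double bkgEff = nBkgPass / static_cast<double>(bkgScores.size());
   return {sigEff, 1.0 - bkgEff};
}

// ROC integral as P(signal response > background response), ties counting 1/2
// (hypothetical helper). Brute force O(nS * nB), fine for a few thousand events.
double rocIntegral(const std::vector<double> &sigScores,
                   const std::vector<double> &bkgScores)
{
   assert(!sigScores.empty() && !bkgScores.empty());
   double wins = 0.0;
   for (double s : sigScores)
      for (double b : bkgScores)
         wins += (s > b) ? 1.0 : (s == b ? 0.5 : 0.0);
   return wins / (static_cast<double>(sigScores.size()) *
                  static_cast<double>(bkgScores.size()));
}
```

For example, signal responses {0.9, 0.8, 0.4} and background responses {0.5, 0.3} give 5 winning pairs out of 6, i.e. a ROC integral of 5/6.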

## Plot Average ROC Curve

In [12]:
result.DrawAvgROCCurve("CrossValidation Avg ROC Curve",kFALSE);
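
An average ROC curve has to combine curves whose points sit at different signal efficiencies. One common approach, vertical averaging, evaluates every fold's curve on a common grid of signal efficiencies by linear interpolation and averages the background rejections at each grid point. The sketch below (hypothetical *RocCurve*, *interpolate* and *averageRoc*) shows that approach; we have not verified that *DrawAvgROCCurve* uses exactly this scheme.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// ROC curve as (signal efficiency, background rejection) points,
// sorted by increasing signal efficiency (hypothetical type).
using RocCurve = std::vector<std::pair<double, double>>;

// Linear interpolation of a curve at signal efficiency x, clamped to the end points.
double interpolate(const RocCurve &curve, double x)
{
   assert(!curve.empty());
   if (x <= curve.front().first) return curve.front().second;
   if (x >= curve.back().first) return curve.back().second;
   for (std::size_t i = 1; i < curve.size(); ++i) {
      if (x <= curve[i].first) {
         // Here curve[i-1].first < x <= curve[i].first, so x1 > x0.
         const double x0 = curve[i - 1].first, y0 = curve[i - 1].second;
         const double x1 = curve[i].first, y1 = curve[i].second;
         return y0 + (y1 - y0) * (x - x0) / (x1 - x0);
      }
   }
   return curve.back().second; // not reached for a sorted curve
}

// Vertical averaging: evaluate all fold curves at nPoints equally spaced
// signal efficiencies in [0, 1] and average the background rejections.
RocCurve averageRoc(const std::vector<RocCurve> &folds, int nPoints)
{
   assert(!folds.empty() && nPoints >= 2);
   RocCurve avg;
   for (int i = 0; i < nPoints; ++i) {
      const double x = static_cast<double>(i) / (nPoints - 1);
      double sum = 0.0;
      for (const auto &c : folds) sum += interpolate(c, x);
      avg.emplace_back(x, sum / static_cast<double>(folds.size()));
   }
   return avg;
}
```

Averaging the rejection at fixed signal efficiency keeps the result directly comparable to the per-fold curves, which share the same axes.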

## Plot ROC Curves and the Average ROC Curve

In [13]:
result.DrawAvgROCCurve("CrossValidation ROC Curves and Avg ROC Curve",kTRUE);

### Close the output file to save all output information (evaluation results of the methods)

In [14]:
outputFile->Close();