In [None]:
import os
from pathlib import Path
from time import time
from typing import Optional, Tuple
from dataclasses import dataclass

import ml
import ROOT
from distributed import Client, LocalCluster, SSHCluster, get_worker
from ml import (
    define_features,
    infer_output_ml_features,
    ml_features_config,
)
from plotting import save_ml_plots, save_plots
from utils import (
    AGCInput,
    AGCResult,
    postprocess_results,
    retrieve_inputs,
    save_histos,
)

In [None]:
# NOTE: the scheduler address needs to be adapted to the actual configuration, either by editing the
# default below or by setting the AGC_DASK_SCHEDULER_ADDRESS environment variable before starting the kernel.
from dask.distributed import Client

DASK_SCHEDULER_ADDRESS = os.environ.get("AGC_DASK_SCHEDULER_ADDRESS", "tls://10.100.218.100:30448")
client = Client(DASK_SCHEDULER_ADDRESS)
client

In [None]:
# Cross sections per process, in pb. Reference values:
# https://atlas-groupdata.web.cern.ch/atlas-groupdata/dev/AnalysisTop/TopDataPreparation/XSection-MC15-13TeV.data
XSEC_INFO = dict(
    # nonallhad + allhad; the same cross-section is used for every ttbar sample
    ttbar=396.87 + 332.97,
    single_top_s_chan=2.0268 + 1.2676,
    # rescaled from the lepton-filtered value to the inclusive one
    single_top_t_chan=(36.993 + 22.175) / 0.252,
    single_top_tW=37.936 + 37.906,
    # e/mu + nu final states only
    wjets=61457 * 0.252,
)

In [None]:
@dataclass
class Args:
    """Run configuration read by ``main``.

    Every attribute carries a type annotation so that ``@dataclass`` turns it into a real field
    (un-annotated class attributes are ignored by the decorator). ``Args()`` keeps the defaults
    below, and individual settings can be overridden, e.g. ``Args(scheduler="mt", ncores=8)``.
    """

    n_max_files_per_sample: int = 1
    data_cache: Optional[str] = None  # shouldn't be used
    remote_data_prefix: str = 'root://eospublic.cern.ch//eos/root-eos/AGC'
    output: str = "histograms.root"
    inference: bool = True
    # "mt" selects local multi-threaded RDataFrame; any other value uses the Dask client
    scheduler: str = "dask-htcondor-swan"
    ncores: Optional[int] = None  # shouldn't be used (only read on the "mt" path)
    npartitions: int = 2
    hosts: Optional[str] = None  # shouldn't be used
    verbose: bool = False

In [None]:
def create_dask_client(scheduler: str, ncores: Optional[int], hosts: Optional[str]) -> Client:
    """Return the Dask client to be used for distributed RDataFrame.

    In this notebook the client is created once, in an earlier cell, as the module-level ``client``.
    This function only hands that object back so that ``main`` can obtain it through a single call.

    Args:
        scheduler: ignored.
        ncores: ignored (``Args`` defaults it to ``None``, hence ``Optional``).
        hosts: ignored (``Args`` defaults it to ``None``, hence ``Optional``).

    Returns:
        The notebook-level ``client``.
    """
    # Whatever `scheduler` says, only the pre-configured notebook client is supported here.
    return client

In [None]:
def make_rdf(
    files: list[str], client: Optional[Client], npartitions: Optional[int]
) -> ROOT.RDataFrame:
    """Build the input dataframe over the ``Events`` tree of ``files``.

    Without a Dask client, a plain (local) ``ROOT.RDataFrame`` is returned. With a client, a
    distributed Dask RDataFrame split into ``npartitions`` partitions is returned instead. In that
    case ``helpers.h``, ``ml_helpers.cpp`` and ``ml.py`` are also registered for distribution to the
    workers, through the private ``_headnode`` attribute of the distributed dataframe.
    """
    if client is None:
        return ROOT.RDataFrame("Events", files)

    distributed_df = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame(
        "Events", files, daskclient=client, npartitions=npartitions
    )
    helper_sources = ["helpers.h", "ml_helpers.cpp", "ml.py"]
    distributed_df._headnode.backend.distribute_unique_paths(helper_sources)
    return distributed_df

In [None]:
def define_trijet_mass(df: ROOT.RDataFrame) -> ROOT.RDataFrame:
    """Restrict ``df`` to events with at least two b-tagged jets and define ``Trijet_mass``.

    A jet counts as b-tagged when ``Jet_btagCSVV2_cut > 0.5``. All three-jet combinations of the
    selected jets are built, and only combinations containing at least one b-tagged jet are kept.
    ``Trijet_mass`` is the invariant mass of the kept combination with the largest transverse momentum.

    The input must already provide the ``Jet_*_cut`` columns. The ``ConstructP4`` function must be
    known to the interpreter (presumably declared by helpers.h — TODO confirm).
    """
    # C++ body of Trijet_btag: mask of combinations where at least one jet passes the b-tag threshold
    btag_mask_expr = """
            auto J1_btagCSVV2 = Take(Jet_btagCSVV2_cut, Trijet_idx[0]);
            auto J2_btagCSVV2 = Take(Jet_btagCSVV2_cut, Trijet_idx[1]);
            auto J3_btagCSVV2 = Take(Jet_btagCSVV2_cut, Trijet_idx[2]);
            return J1_btagCSVV2 > 0.5 || J2_btagCSVV2 > 0.5 || J3_btagCSVV2 > 0.5;
            """
    # C++ body of Trijet_p4: summed four-momenta of the combinations selected by Trijet_btag
    trijet_p4_expr = """
        auto J1 = Take(Jet_p4, Trijet_idx[0]);
        auto J2 = Take(Jet_p4, Trijet_idx[1]);
        auto J3 = Take(Jet_p4, Trijet_idx[2]);
        return (J1+J2+J3)[Trijet_btag];
        """

    return (
        df.Filter("Sum(Jet_btagCSVV2_cut > 0.5) > 1")
        .Define("Jet_p4", "ConstructP4(Jet_pt_cut, Jet_eta_cut, Jet_phi_cut, Jet_mass_cut)")
        .Define("Trijet_idx", "Combinations(Jet_pt_cut, 3)")
        .Define("Trijet_btag", btag_mask_expr)
        .Define("Trijet_p4", trijet_p4_expr)
        .Define(
            "Trijet_pt",
            "return Map(Trijet_p4, [](const ROOT::Math::PxPyPzMVector &v) { return v.Pt(); })",
        )
        # mass of the highest-pt trijet among those passing the b-tag requirement
        .Define("Trijet_mass", "Trijet_p4[ArgMax(Trijet_pt)].M()")
    )

In [None]:
def book_histos(
    df: ROOT.RDataFrame, process: str, variation: str, nevents: int, inference=False
) -> Tuple[list[AGCResult], list[AGCResult]]:
    """Return the pair of lists of RDataFrame results pertaining to the given process and variation.

    The first list contains histograms of reconstructed HT (region 4j1b) and trijet masses (region 4j2b).
    The second contains histograms of the ML inference outputs; it is empty unless ``inference`` is True.

    Args:
        df: dataframe (local or distributed) over the files of one sample.
        process: process name; must be a key of ``XSEC_INFO``. Also used in histogram names and titles.
        variation: sample variation; systematic variations (``Vary``) are booked only for "nominal",
            and only those results are flagged with ``should_vary=True``.
        nevents: event count used to normalise the cross-section weight (``x_sec * lumi / nevents``).
        inference: whether to also book the ML-inference histograms.

    Results are only booked: no event loop has run when this function returns (RDataFrame is lazy).
    """
    # Calculate normalization for MC
    x_sec = XSEC_INFO[process]
    lumi = 3378  # /pb
    xsec_weight = x_sec * lumi / nevents
    df = df.Define("Weights", str(xsec_weight))  # default weights

    # Systematic variations are booked (and later expanded with VariationsFor) for nominal samples only
    should_vary = variation == "nominal"

    if should_vary:
        # Jet_pt variations definition: both expressions rescale Jet_pt.
        # pt_scale_up()                    -> "pt_scale_up" (jet energy scale systematic)
        # jet_pt_resolution(Jet_pt.size()) -> "pt_res_up"   (jet resolution systematic)
        df = df.Vary(
            "Jet_pt",
            "ROOT::RVec<ROOT::RVecF>{Jet_pt*pt_scale_up(), Jet_pt*jet_pt_resolution(Jet_pt.size())}",
            ["pt_scale_up", "pt_res_up"],
        )

        if process == "wjets":
            # Flat weight variation definition
            df = df.Vary(
                "Weights",
                "Weights*flat_variation()",
                [f"scale_var_{direction}" for direction in ["up", "down"]],
            )

    # Event selection - the core part of the algorithm applied for both regions
    # Selecting events containing exactly one lepton and at least four jets with pT > 30 GeV
    # (the b-tag requirements are applied per region below)
    df = (
        df.Define(
            "Electron_mask",
            "Electron_pt > 30 && abs(Electron_eta) < 2.1 && Electron_sip3d < 4 && Electron_cutBased == 4",
        )
        .Define(
            "Muon_mask",
            "Muon_pt > 30 && abs(Muon_eta) < 2.1 && Muon_sip3d < 4 && Muon_tightId && Muon_pfRelIso04_all < 0.15",
        )
        .Filter("Sum(Electron_mask) + Sum(Muon_mask) == 1")
        .Define("Jet_mask", "Jet_pt > 30 && abs(Jet_eta) < 2.4 && Jet_jetId == 6")
        .Filter("Sum(Jet_mask) >= 4")
    )

    # create columns for "good" jets
    df = (
        df.Define("Jet_pt_cut", "Jet_pt[Jet_mask]")
        .Define("Jet_btagCSVV2_cut", "Jet_btagCSVV2[Jet_mask]")
        .Define("Jet_eta_cut", "Jet_eta[Jet_mask]")
        .Define("Jet_phi_cut", "Jet_phi[Jet_mask]")
        .Define("Jet_mass_cut", "Jet_mass[Jet_mask]")
    )

    # b-tagging variations for nominal samples: 4 sources x (up, down) = 8 weight variations
    if should_vary:
        df = df.Vary(
            "Weights",
            "ROOT::RVecD{Weights*btag_weight_variation(Jet_pt_cut)}",
            [
                f"{weight_name}_{direction}"
                for weight_name in [f"btag_var_{i}" for i in range(4)]
                for direction in ["up", "down"]
            ],
        )

    # Define HT observable for the 4j1b region
    # Exactly one b-tagged jet required
    # The observable is the scalar sum of the selected jets' transverse momenta
    # fmt: off
    df4j1b = df.Filter("Sum(Jet_btagCSVV2_cut > 0.5) == 1")\
               .Define("HT", "Sum(Jet_pt_cut)")
    # fmt: on

    # Define trijet_mass observable for the 4j2b region (this one is more complicated)
    df4j2b = define_trijet_mass(df)

    # Book histograms and, if needed, their systematic variations.
    # `region_df` is used as the loop name so the function's `df` is not shadowed.
    results: list[AGCResult] = []
    for region_df, observable, region in zip([df4j1b, df4j2b], ["HT", "Trijet_mass"], ["4j1b", "4j2b"]):
        histo_model = ROOT.RDF.TH1DModel(
            name=f"{region}_{process}_{variation}", title=process, nbinsx=25, xlow=50, xup=550
        )
        nominal_histo = region_df.Histo1D(histo_model, observable, "Weights")
        results.append(
            AGCResult(nominal_histo, region, process, variation, nominal_histo, should_vary=should_vary)
        )
        print(f"Booked histogram {histo_model.fName}")

    ml_results: list[AGCResult] = []

    if not inference:
        return (results, ml_results)

    df4j2b = define_features(df4j2b)
    df4j2b = infer_output_ml_features(df4j2b)

    # Book histograms and, if needed, their systematic variations
    for i, feature in enumerate(ml_features_config):
        histo_model = ROOT.RDF.TH1DModel(
            name=f"{feature.name}_{process}_{variation}",
            title=feature.title,
            nbinsx=feature.binning[0],
            xlow=feature.binning[1],
            xup=feature.binning[2],
        )

        # assumes infer_output_ml_features defines one "results{i}" column per ml_features_config entry
        nominal_histo = df4j2b.Histo1D(histo_model, f"results{i}", "Weights")
        ml_results.append(
            AGCResult(nominal_histo, feature.name, process, variation, nominal_histo, should_vary=should_vary)
        )
        print(f"Booked histogram {histo_model.fName}")

    # Return the booked results
    # Note that no event loop has run yet at this point (RDataFrame is lazy)
    return (results, ml_results)

In [None]:
def compile_macro_wrapper(library_path: str):
    """Compile and load ``library_path`` with ACLiC (options "kO") while holding the interpreter lock.

    A small C++ function ``CompileMacroWrapper`` is declared once (an include guard makes repeated
    declarations harmless). It takes ``gInterpreterMutex`` before calling ``gSystem->CompileMacro``.

    Raises:
        RuntimeError: if ``CompileMacro`` does not report success (a return value other than 1).
    """
    wrapper_source = '''
    #ifndef R__COMPILE_MACRO_WRAPPER
    #define R__COMPILE_MACRO_WRAPPER
    int CompileMacroWrapper(const std::string &library_path)
    {
        R__LOCKGUARD(gInterpreterMutex);
        return gSystem->CompileMacro(library_path.c_str(), "kO");
    }
    #endif // R__COMPILE_MACRO_WRAPPER
    '''
    ROOT.gInterpreter.Declare(wrapper_source)

    status = ROOT.CompileMacroWrapper(library_path)
    if status != 1:
        raise RuntimeError("Failure in TSystem::CompileMacro!")

In [None]:
def load_cpp():
    """Compile and load ``helpers.h`` on the current Dask worker, at most once per worker.

    The worker's copy of ``helpers.h`` is looked up in the worker's ``local_directory``. It is
    presumably placed there by ``distribute_unique_paths`` in ``make_rdf``. After a successful
    compilation, an ``is_library_loaded`` marker is set on the worker object so later calls skip the
    work. If ``compile_macro_wrapper`` raises, the marker is not set.

    Outside a Dask worker, ``get_worker`` raises ``ValueError``. The function then prints a message
    and returns without compiling anything.
    """
    try:
        this_worker = get_worker()
    except ValueError:
        print("Not on a worker")
        return

    if hasattr(this_worker, "is_library_loaded"):
        print("Didn't try to compile the macro.")
        return

    print("Compiling the macro.")
    # Reuse the worker handle obtained above rather than calling get_worker() a second time.
    library_path = os.path.join(this_worker.local_directory, "helpers.h")
    compile_macro_wrapper(library_path)
    this_worker.is_library_loaded = True

In [None]:
def main() -> None:
    """Run the complete analysis: book all histograms, execute the event loops, then plot and save.

    Settings come from ``Args()``. With ``scheduler == "mt"`` the dataframes run locally with implicit
    multi-threading. Any other scheduler uses the Dask client returned by ``create_dask_client`` and
    distributed RDataFrame; that client is closed at the end. Histograms are saved to ``args.output``.
    When ``args.inference`` is set, the ML-inference histograms are saved under the same name, with
    the ``.root`` suffix replaced by ``_ml_inference.root``.
    """
    program_start = time()
    args = Args()

    # Do not add histograms to TDirectories automatically: we'll do it ourselves as needed.
    ROOT.TH1.AddDirectory(False)
    # Disable interactive graphics: avoids canvases flashing on screen before we save them to file
    ROOT.gROOT.SetBatch(True)

    if args.verbose:
        # Set higher RDF verbosity for the rest of the program.
        # To only change the verbosity in a given scope, use ROOT.Experimental.RLogScopedVerbosity.
        ROOT.Detail.RDF.RDFLogChannel.SetVerbosity(ROOT.Experimental.ELogLevel.kInfo)

    if args.scheduler == "mt":
        # Setup for local, multi-thread RDataFrame.
        # ncores=None maps to 0, which lets ROOT choose the number of threads.
        ROOT.EnableImplicitMT(args.ncores or 0)
        print(f"Number of threads: {ROOT.GetThreadPoolSize()}")
        client = None
        # load_cpp() returns early when it is not running on a Dask worker, so in this local mode it would
        # never compile the C++ helpers. Compile the local copy of helpers.h directly instead.
        compile_macro_wrapper("helpers.h")
        if args.inference:
            ml.load_cpp("./fastforest")

        run_graphs = ROOT.RDF.RunGraphs
        variationsfor_func = ROOT.RDF.Experimental.VariationsFor
    else:
        # Setup for distributed RDataFrame
        client = create_dask_client(args.scheduler, args.ncores, args.hosts)
        if args.inference:
            def load_all(fastforest_path):
                load_cpp()
                ml.load_cpp(fastforest_path)

            # NOTE(review): absolute path inside a user's EOS area; consider making it configurable via Args.
            fastforest_path = "/eos/user/e/eguiraud/SWAN_projects/analysis-grand-challenge-root/analyses/cms-open-data-ttbar/fastforest"
            ROOT.RDF.Experimental.Distributed.initialize(load_all, fastforest_path)
        else:
            ROOT.RDF.Experimental.Distributed.initialize(load_cpp)
        run_graphs = ROOT.RDF.Experimental.Distributed.RunGraphs
        variationsfor_func = ROOT.RDF.Experimental.Distributed.VariationsFor

    # Book RDataFrame results
    inputs: list[AGCInput] = retrieve_inputs(
        args.n_max_files_per_sample, args.remote_data_prefix, args.data_cache
    )
    results: list[AGCResult] = []
    ml_results: list[AGCResult] = []

    for agc_input in inputs:
        df = make_rdf(agc_input.paths, client, args.npartitions)
        hist_list, ml_hist_list = book_histos(
            df, agc_input.process, agc_input.variation, agc_input.nevents, inference=args.inference
        )
        results += hist_list
        ml_results += ml_hist_list

    # Wrap results of nominal samples so that all their systematic variations are produced too
    for r in results + ml_results:
        if r.should_vary:
            r.histo = variationsfor_func(r.histo)

    print(f"Building the computation graphs took {time() - program_start:.2f} seconds")

    # FIXME remove this debug workaround
    # NOTE(review): reading a result's value makes RDataFrame run that graph's event loop here,
    # before run_graphs below, so its cost is excluded from the "Executing" timing.
    if results:
        print("TEST RUN START")
        print(results[0].nominal_histo.GetEntries())
        print("TEST RUN END")

    # Run the event loops for all processes and variations here
    run_graphs_start = time()
    run_graphs([r.nominal_histo for r in results + ml_results])

    print(f"Executing the computation graphs took {time() - run_graphs_start:.2f} seconds")

    results = postprocess_results(results)
    save_plots(results)
    save_histos([r.histo for r in results], output_fname=args.output)
    print(f"Result histograms saved in file {args.output}")

    if args.inference:
        ml_results = postprocess_results(ml_results)
        save_ml_plots(ml_results)
        # Strip only a trailing ".root": split(".root") would also cut at a ".root" inside a directory name
        output_fname = args.output.removesuffix(".root") + "_ml_inference.root"
        save_histos([r.histo for r in ml_results], output_fname=output_fname)
        print(f"Result histograms from ML inference step saved in file {output_fname}")

    # FIXME this was moved down here because it looks like postprocess_results still needs the client,
    # but it might be a side-effect of errors happening in the event loop.
    # NOTE(review): in this notebook the client is the one created in the client cell, so calling main()
    # again requires re-running that cell first.
    if client is not None:
        client.close()

In [None]:
# Run the whole analysis. With the default Args (a Dask scheduler), main() uses the `client` created
# in the client cell above and closes it when done, so re-run that cell before calling main() again.
main()