# In [None]:
import time
import awkward as ak
import numpy as np
from coffea.nanoevents import NanoEventsFactory, BaseSchema, NanoAODSchema

# Input ntuples: the same H->bb sample without BIB ("sig_only") and with BIB ("sig_and_BIB").
fname_wobib = "/nfs_scratch/slomte/MuColl/runjobs10bib/singleHiggs/bib100/analysis/JetHistogramsHbb_sig_only_bib100_2mev_250ps_ak5_10kEvents.root"
fname_wbib = "/nfs_scratch/slomte/MuColl/runjobs10bib/singleHiggs/bib100/analysis/JetHistogramsHbb_sig_and_BIB_bib100_2mev_250ps_ak5_10kEvents.root"


def _load_tree(path, treepath):
    """Read the tree ``treepath`` from ``path`` as NanoEvents using ``BaseSchema``.

    No dataset metadata is attached; ``MyProcessor.process`` receives an explicit
    label instead.
    """
    factory = NanoEventsFactory.from_root(path, treepath, schemaclass=BaseSchema)
    return factory.events()


# Gen-level jets come from the no-BIB file; reco jets are read from both files.
eve_gen = _load_tree(fname_wobib, "JetHistogramGenJetTuple;1")
eve_reco_wobib = _load_tree(fname_wobib, "JetHistogramRecoJetTuple;1")
eve_reco_wbib = _load_tree(fname_wbib, "JetHistogramRecoJetTuple;1")


# In [None]:
from coffea import processor
from coffea.nanoevents.methods import candidate, vector, nanoaod
from coffea import hist
from functools import partial
import math
import numba
import pickle
import re

import matplotlib.pyplot as plt
import scipy.optimize as opt;


def selectJets(events):
    """Build pt-ordered jet four-vectors from the flat jet branches of ``events``.

    Reads ``jmox``/``jmoy``/``jmoz`` (momentum components), ``jene`` (used as the
    time component ``t``) and ``jcost`` (kept as an extra ``cost`` field), zips them
    into ``LorentzVector`` records, and sorts the jets of each event by descending pt.
    """
    fields = dict(
        x=events.jmox,
        y=events.jmoy,
        z=events.jmoz,
        t=events.jene,
        cost=events.jcost,
    )
    four_vectors = ak.zip(fields, with_name="LorentzVector", behavior=vector.behavior)
    by_pt_desc = ak.argsort(four_vectors.pt, ascending=False)
    return four_vectors[by_pt_desc]

def findGenRecoMatch(i_gjets, i_rjets, dr_max=0.25):
    """Pair generator-level jets with their closest reconstructed jet.

    For every gen jet, the distance to every reco jet in the same event is taken
    from ``metric_table`` (the callers treat it as Delta R). Gen jets with no reco
    jet closer than ``dr_max`` are dropped. Each surviving gen jet is paired with
    the reco jet at minimum distance.

    Parameters
    ----------
    i_gjets, i_rjets :
        Per-event jet collections supporting ``metric_table`` (e.g. the
        LorentzVector arrays built by ``selectJets``).
    dr_max : float, optional
        Matching threshold; defaults to 0.25, the previously hard-coded value.

    Returns
    -------
    Zipped array with fields ``'gJet'`` and ``'rJet'`` that share the same jagged
    structure (one entry per matched gen jet).
    """
    dr_table = i_gjets.metric_table(i_rjets, axis=-1)
    has_match = ak.any(dr_table < dr_max, axis=-1)
    o_gjets = i_gjets[has_match]
    # Reuse the gen x reco table, keeping only the matched gen jets' rows,
    # instead of recomputing metric_table for the filtered gen jets.
    closest = ak.argmin(dr_table[has_match], axis=-1)
    o_rjets = i_rjets[closest]
    return ak.zip({'gJet': o_gjets, 'rJet': o_rjets})



class MyProcessor(processor.ProcessorABC):
    """Jet histogramming and jet-performance study for the H->bb samples.

    ``process`` fills histograms for one sample, labelled by the caller.
    ``performance_Jet`` matches gen and reco jets and derives momentum-corrected reco jets.
    """

    def __init__(self):
        # Categories
        dataset_axis = hist.Cat("dataset", "")  # gen, reco with, or reco without BIB
        evtlabel_axis = hist.Cat("evtlabel", "")  # describing event, e.g 2jets in event
        jetIndex_axis = hist.Cat("jetIndex", "")  # leading or subleading jet
        cut_axis = hist.Cat("cut", "")  # Acceptance cut

        # Variables
        n_axis = hist.Bin("n", "N", 20, 0, 20)
        pt_axis = hist.Bin("pt", r"$p_{T}$ [GeV]", 30, 1.0, 300.)
        eta_axis = hist.Bin("eta", r"$\eta$", 50, -5.0, 5.0)
        phi_axis = hist.Bin("phi", r"$\phi$", 50, -5.0, 5.0)
        energy_axis = hist.Bin("energy", r"$E$ [GeV]", 50, 1.0, 500.)
        mass_axis = hist.Bin("mass", r"mass [GeV]", 40, 1.0, 200.)
        cost_axis = hist.Bin("cost", r"$cos\theta$", 22, -1.1, 1.1)
        dr_axis = hist.Bin("dr", r"$\Delta R$", 50, 0., 5.0)

        ### Accumulator for holding histograms
        self._accumulator = processor.dict_accumulator(
            {
                "EventCount": processor.value_accumulator(int),
                "cutflow" : processor.defaultdict_accumulator(partial(processor.defaultdict_accumulator, int)),

                # jet kinematics
                "Jets_N" : hist.Hist("Events", dataset_axis, cut_axis, n_axis),
                "Jet_pt": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, pt_axis),
                "Jet_eta": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, eta_axis),
                "Jet_phi": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, phi_axis),
                "Jet_energy": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, energy_axis),
                "Jet_mass": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, mass_axis),
                "Jet_cosTheta": hist.Hist("Events", dataset_axis, cut_axis, jetIndex_axis, cost_axis),

                # di-jet kinematics
                "dijet_mass": hist.Hist("Events", dataset_axis, evtlabel_axis, mass_axis),
                "trijet_mass": hist.Hist("Events", dataset_axis, evtlabel_axis, mass_axis),
                "Inv_Mass": hist.Hist("Events", dataset_axis, evtlabel_axis, mass_axis),
                "H_pt": hist.Hist("Events", dataset_axis, cut_axis, pt_axis),
                "H_eta": hist.Hist("Events", dataset_axis, cut_axis, eta_axis),
                "H_phi": hist.Hist("Events", dataset_axis, cut_axis, phi_axis),
                "H_energy": hist.Hist("Events", dataset_axis, cut_axis, energy_axis),

                "dR_j1j2": hist.Hist("Events", dataset_axis, cut_axis, dr_axis),
            }
        )

    @property
    def accumulator(self):
        return self._accumulator

    @staticmethod
    def _fill_jet_kinematics(output, datalabel, jets, cut, jetIndex='All'):
        """Fill Jet_pt/eta/phi/mass/energy/cosTheta with every jet in ``jets``.

        ``jets`` is flattened over events, so each jet is one histogram entry.
        """
        per_jet = (
            ("Jet_pt", "pt", jets.pt),
            ("Jet_eta", "eta", jets.eta),
            ("Jet_phi", "phi", jets.phi),
            ("Jet_mass", "mass", jets.mass),
            ("Jet_energy", "energy", jets.energy),
            ("Jet_cosTheta", "cost", jets.cost),
        )
        for hist_name, axis_name, values in per_jet:
            output[hist_name].fill(
                dataset=datalabel,
                jetIndex=jetIndex,
                cut=cut,
                **{axis_name: ak.flatten(values)},
            )

    def process(self, events, datalabel, corr_reco_jets=None):
        """Fill the jet histograms for one sample.

        Parameters
        ----------
        events :
            Events exposing the flat jet branches read by ``selectJets``
            (``jmox``, ``jmoy``, ``jmoz``, ``jene``, ``jcost``).
        datalabel : str
            Value of the ``dataset`` category for every fill (e.g. 'truth-level').
        corr_reco_jets : optional
            Momentum-corrected reco jets (third return value of ``performance_Jet``).
            When omitted, the module-level global ``CorrRecoJet`` is used.

        Returns
        -------
        A fresh accumulator (``self.accumulator.identity()``) with the filled histograms.
        """
        # NOTE(review): coffea's ProcessorABC.process takes only ``events``; the extra
        # ``datalabel`` argument means this processor is driven by direct calls.
        output = self.accumulator.identity()

        #--- jet ---#
        jets = selectJets(events)

        # Jet multiplicity for a scan of acceptance cuts.
        output["Jets_N"].fill(dataset=datalabel, n=ak.num(jets, axis=1), cut='nocut')
        multiplicity_cuts = {
            'pt20': jets.pt>20,
            'pt20cost0p98': (jets.pt>20) & (abs(jets.cost)<0.98),
            'pt20cost0p9': (jets.pt>20) & (abs(jets.cost)<0.9),
            'pt20cost0p8': (jets.pt>20) & (abs(jets.cost)<0.8),
        }
        for cut_name, cut_mask in multiplicity_cuts.items():
            output["Jets_N"].fill(dataset=datalabel, n=ak.num(jets[cut_mask], axis=1), cut=cut_name)

        # Object selections (acceptance cuts)
        PtTheta_cut = multiplicity_cuts['pt20cost0p8']
        goodjets = jets[PtTheta_cut]

        # Good-jet kinematics. Jets_N for 'pt20cost0p8' is already filled by the scan
        # above; filling it again here would count every event twice in that bin.
        self._fill_jet_kinematics(output, datalabel, goodjets, 'pt20cost0p8')

        # Kinematics of corrected reco jets.
        # NOTE(review): these come from performance_Jet, not from ``events``, so the same
        # jets are filled under whatever ``datalabel`` this call uses.
        if corr_reco_jets is None:
            corr_reco_jets = CorrRecoJet
        # Label only; keep in sync with the ``method`` chosen in performance_Jet.
        method = 1
        crjet_cut = f'crjets_pt20cost0p8_method{method}'
        PtTheta_cut_rj = (corr_reco_jets.pt>20) & (abs(corr_reco_jets.cost)<0.8)
        good_crjets = corr_reco_jets[PtTheta_cut_rj]
        output["Jets_N"].fill(dataset=datalabel, n=ak.num(good_crjets, axis=1), cut=crjet_cut)
        self._fill_jet_kinematics(output, datalabel, good_crjets, crjet_cut)

        #------ Higgs reconstruction from dijet or trijet ------#
        '''
        ptcut=20
        costhetacut=0.8
        goodjets = jets[(jets.pt>ptcut) & (abs(jets.cost)<costhetacut)]
        jet1 = ak.mask(goodjets, ak.num(goodjets, axis=1)==2)[:,0]
        jet2 = ak.mask(goodjets, ak.num(goodjets, axis=1)==2)[:,1]
        dijet = jet1.add(jet2)
        jet1 = ak.mask(goodjets, ak.num(goodjets, axis=1)>2)[:,0]
        jet2 = ak.mask(goodjets, ak.num(goodjets, axis=1)>2)[:,1]
        jet3 = ak.mask(goodjets, ak.num(goodjets, axis=1)>2)[:,2]
        trijet = (jet1.add(jet2)).add(jet3)        
        dijet_mass = ak.fill_none(dijet.mass, 0)
        trijet_mass = ak.fill_none(trijet.mass, 0)
        dijet_pt = ak.fill_none(dijet.pt,0)
        trijet_pt = ak.fill_none(trijet.pt,0)
        dijet_energy = ak.fill_none(dijet.energy,0)
        trijet_energy = ak.fill_none(trijet.energy,0)
        dijet_eta = ak.fill_none(dijet.eta,1000)
        trijet_eta = ak.fill_none(trijet.eta,1000)
        dijet_phi = ak.fill_none(dijet.phi,1000)
        trijet_phi = ak.fill_none(trijet.phi,1000)
        
        # reco Higgs kinematics
        output['dijet_mass'].fill(dataset=datalabel, evtlabel='=2j evts', mass=dijet_mass)
        output['trijet_mass'].fill(dataset=datalabel, evtlabel='>2j evts', mass=trijet_mass)        
        output['Inv_Mass'].fill(dataset=datalabel, evtlabel='>=2j evts', mass=dijet_mass)
        output['Inv_Mass'].fill(dataset=datalabel, evtlabel='>=2j evts', mass=trijet_mass)
        # for >=2jet events
        output['H_pt'].fill(dataset=datalabel, pt=dijet_pt, cut='pt20cost0p8')
        output['H_pt'].fill(dataset=datalabel, pt=trijet_pt, cut='pt20cost0p8')
        output['H_eta'].fill(dataset=datalabel, eta=dijet_eta, cut='pt20cost0p8')
        output['H_eta'].fill(dataset=datalabel, eta=trijet_eta, cut='pt20cost0p8')
        output['H_phi'].fill(dataset=datalabel, phi=dijet_phi, cut='pt20cost0p8')
        output['H_phi'].fill(dataset=datalabel, phi=trijet_phi, cut='pt20cost0p8')
        output['H_energy'].fill(dataset=datalabel, energy=dijet_energy, cut='pt20cost0p8')
        output['H_energy'].fill(dataset=datalabel, energy=trijet_energy, cut='pt20cost0p8')
        
        # Higgs jet constituent kinematics (j1,j2,j3) 
        jet1 = ak.mask(goodjets, ak.num(goodjets, axis=1)>1)[:,0]
        jet2 = ak.mask(goodjets, ak.num(goodjets, axis=1)>1)[:,1]
        jet3 = ak.mask(goodjets, ak.num(goodjets, axis=1)>2)[:,2]
        output["Jet_pt"].fill(dataset=datalabel, pt=ak.fill_none(jet1.pt,-1), jetIndex='j1', cut='pt20cost0p8')
        output["Jet_pt"].fill(dataset=datalabel, pt=ak.fill_none(jet2.pt,-1), jetIndex='j2', cut='pt20cost0p8')
        output["Jet_pt"].fill(dataset=datalabel, pt=ak.fill_none(jet3.pt,-1), jetIndex='j3', cut='pt20cost0p8')
        output["Jet_eta"].fill(dataset=datalabel, eta=ak.fill_none(jet1.eta,1000), jetIndex='j1', cut='pt20cost0p8')
        output["Jet_eta"].fill(dataset=datalabel, eta=ak.fill_none(jet2.eta,1000), jetIndex='j2', cut='pt20cost0p8')
        output["Jet_eta"].fill(dataset=datalabel, eta=ak.fill_none(jet3.eta,1000), jetIndex='j3', cut='pt20cost0p8')
        output["Jet_phi"].fill(dataset=datalabel, phi=ak.fill_none(jet1.phi,1000), jetIndex='j1', cut='pt20cost0p8')
        output["Jet_phi"].fill(dataset=datalabel, phi=ak.fill_none(jet2.phi,1000), jetIndex='j2', cut='pt20cost0p8')
        output["Jet_phi"].fill(dataset=datalabel, phi=ak.fill_none(jet3.phi,1000), jetIndex='j3', cut='pt20cost0p8')
        output["Jet_mass"].fill(dataset=datalabel, mass=ak.fill_none(jet1.mass,-1), jetIndex='j1', cut='pt20cost0p8')
        output["Jet_mass"].fill(dataset=datalabel, mass=ak.fill_none(jet2.mass,-1), jetIndex='j2', cut='pt20cost0p8')
        output["Jet_mass"].fill(dataset=datalabel, mass=ak.fill_none(jet3.mass,-1), jetIndex='j3', cut='pt20cost0p8')
        output["Jet_energy"].fill(dataset=datalabel, energy=ak.fill_none(jet1.energy,-1), jetIndex='j1', cut='pt20cost0p8')
        output["Jet_energy"].fill(dataset=datalabel, energy=ak.fill_none(jet2.energy,-1), jetIndex='j2', cut='pt20cost0p8')
        output["Jet_energy"].fill(dataset=datalabel, energy=ak.fill_none(jet3.energy,-1), jetIndex='j3', cut='pt20cost0p8')
        
        # dR between (j1,j2)
        dr_j1j2 = jet1.delta_r(jet2)
        output["dR_j1j2"].fill(dataset=datalabel, dr=ak.fill_none(dr_j1j2,-1), cut='pt20cost0p8')
        '''

        return output

    # --------------------------------------------------------------------------------------------------------------------------------------------------------- #

    def performance_Jet(self, events):
        """Study jet reconstruction performance and derive reco-jet momentum corrections.

        Steps:
          * count gen and reco jets (pt > 10 GeV) in |eta| bins and take their ratio;
          * match gen and reco jets with ``findGenRecoMatch``;
          * average matched gen/reco jet E and pT in (|eta|, E) and (|eta|, pT) bins, and
            fit <gen> = a*<reco>**b + c in every |eta| region;
          * build per-jet scale factors (three methods, selected by ``method``) and apply
            them to the matched reco jets' four-momenta;
          * print pT, E and eta resolutions of the corrected jets.

        Parameters
        ----------
        events :
            NOTE(review): unused. The method reads the module-level globals ``eve_gen``
            and ``eve_reco_wobib`` directly.

        Returns
        -------
        (gjets, rjets, corr_reco_Jets): matched gen jets, their matched reco jets and the
        scale-factor-corrected reco jets, sharing one jagged structure.
        """
        # -------------- Reconstruction efficiency -------------#
        # If we want to find the efficiency to reconstructed jets in a given pT and Eta region, we can do the following.
        # However, if we want to find the efficiency of b-tagging algorithm, which calculates how many of real b-jets are reconstructed or tagged as b-jets. This is the effi of the algorithm to correctly tag a b-jet.
        # But the MC full sim framework may not have b-tagger. Need to recheck this. <--

        # Count jets with pT > 10 GeV per |eta| bin at gen and reco level.
        recojets = selectJets(eve_reco_wobib)
        genjets = selectJets(eve_gen)
        recojets = recojets[recojets.pt>10]
        genjets = genjets[genjets.pt>10]

        geta = ak.fill_none(abs(genjets.eta), -10000000)
        reta = ak.fill_none(abs(recojets.eta), -10000000)

        Eta_bins = np.linspace(0., 2.5, num=11)
        n_eta = len(Eta_bins) - 1

        Njets_gen = np.ones(n_eta)
        Njets_reco = np.ones(n_eta)
        for j in range(n_eta):
            getabin = (geta>=Eta_bins[j]) & (geta<Eta_bins[j+1])
            retabin = (reta>=Eta_bins[j]) & (reta<Eta_bins[j+1])
            Njets_gen[j] = ak.sum(ak.num(genjets[getabin]))
            Njets_reco[j] = ak.sum(ak.num(recojets[retabin]))
        # Gen and reco counts are taken independently (no matching), so this is a
        # count ratio per bin rather than a per-jet efficiency.
        Reco_efficiency = Njets_reco/Njets_gen

        #Plot reco efficiency as func of eta:
        r'''
        fig, ax = plt.subplots(1)
        plt.step(x=Eta_bins[:10], y=Reco_efficiency, where='post')
        plt.xlabel(r'Jet $|\eta|$', fontsize=14)
        plt.ylabel('Jet reconstruction efficiency', fontsize=14)
        plt.ylim(0,1)
        process = plt.text(1.0, 1.0, r"$H \rightarrow b \bar{b}$ (1.5TeV)", fontsize=12, horizontalalignment="right", verticalalignment="bottom", transform=ax.transAxes)
        MC = plt.text(0.0, 1.0, r"$\bf{Muon Collider}$ Preliminary", fontsize=12, horizontalalignment="left", verticalalignment="bottom", transform=ax.transAxes)
        Sidenote1 = plt.text(0.75, 0.67,r"$p_{T}>10$ GeV", fontsize=12, transform=ax.transAxes)
        plt.savefig('recoEffiEta_pt10.png')
        '''
        # --------------------------------------------------#

        # -------------- Jet momentum Correction -------------#
        # Why: to recover energy lost in reconstruction.
        # How: (1) for every gen jet, find the closest reco jet, keeping pairs with dR < 0.25
        #      (2) in |eta| x E and |eta| x pT regions, average matched gen and reco E / pT
        #      (3) fit <gen> vs <reco> in every |eta| region
        #      (4) derive a scale factor per reco jet and rescale its four-momentum

        # Step (1):
        # findGenRecoMatch returns a zipped array with fields 'gJet' and 'rJet' (LorentzVectors).
        # Gen jets without a reco jet within dR < 0.25 are removed, so both fields share one structure.
        genjets = selectJets(eve_gen)
        recojets = selectJets(eve_reco_wobib)
        gen_reco_pairs = findGenRecoMatch(genjets, recojets)

        gene = ak.fill_none(gen_reco_pairs.gJet.energy, -10000000)
        gpt = ak.fill_none(gen_reco_pairs.gJet.pt, -10000000)
        geta = ak.fill_none(abs(gen_reco_pairs.gJet.eta), -10000000)
        reta = ak.fill_none(abs(gen_reco_pairs.rJet.eta), -10000000)
        gjets = gen_reco_pairs.gJet
        rjets = gen_reco_pairs.rJet

        # Step (2):
        # For each |eta| region, average E and pT of matched gen and reco jets,
        # binned in gen-jet E (resp. pT).
        # E and pT bins: [10, 20, ..., 200] GeV --> 19 bins
        # |eta| bins: [0, 0.25, 0.50, ..., 2.25, 2.5] --> 10 bins
        etaregions = np.linspace(0., 2.5, num=11)
        eneregions = np.linspace(10, 200, num=20)
        ptregions = np.linspace(10, 200, num=20)
        n_eta = len(etaregions) - 1
        n_kin = len(eneregions) - 1
        gjets_meanE_reg = np.ones((n_eta, n_kin))
        rjets_meanE_reg = np.ones((n_eta, n_kin))
        gjets_meanPT_reg = np.ones((n_eta, n_kin))
        rjets_meanPT_reg = np.ones((n_eta, n_kin))

        for j in range(n_eta):
            etareg = (geta>=etaregions[j]) & (geta<etaregions[j+1])
            for i in range(n_kin):
                enereg = (gene>=eneregions[i]) & (gene<eneregions[i+1])
                ptreg = (gpt>=ptregions[i]) & (gpt<ptregions[i+1])

                # Mean E and mean pT of matched gen and reco jets per (eta, E/pT) region.
                gjets_meanE_reg[j][i] = ak.mean(gjets.energy[enereg & etareg], axis=None)
                rjets_meanE_reg[j][i] = ak.mean(rjets.energy[enereg & etareg], axis=None)
                gjets_meanPT_reg[j][i] = ak.mean(gjets.pt[ptreg & etareg], axis=None)
                rjets_meanPT_reg[j][i] = ak.mean(rjets.pt[ptreg & etareg], axis=None)

        # Response:
        # R(Eavg, eta) = <E_reco>/<E_gen> [E_gen, eta_gen]; [] imply binning variables and < > imply average within bins of variables
        Response_E = rjets_meanE_reg/gjets_meanE_reg
        #Response_PT = rjets_meanPT_reg/gjets_meanPT_reg

        # Plot the response vs |eta| for several gen-energy regions:
        r'''
        plt.plot(etaregions[:10]+0.125, Response_E[:,9], label='100-110GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,10], label='110-120GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,11], label='120-130GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,12], label='130-140GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,13], label='140-150GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,14], label='150-160GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,15], label='160-170GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,16], label='170-180GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,17], label='180-190GeV')
        plt.plot(etaregions[:10]+0.125, Response_E[:,18], label='190-200GeV')
        plt.legend(loc=1)
        plt.ylim(0,1.2)
        plt.xlim(0,3.8)
        plt.xlabel(r'Jet $|\eta|$', fontsize=13)
        plt.ylabel('Response', fontsize=13)
        plt.title(r'R = $\frac{<E_{reco}>}{<E_{gen}>}[E_{gen}, \eta_{gen}]$', fontsize=15)
        plt.savefig('Response_2.png')
        '''

        # Step (3):
        # Fit the power law <gen> = a*<reco>**b + c in every |eta| region.
        def fitted_function(x, a, b, c):
            return a*(x**b) + c

        def fit_region(x, y):
            # Bins with no matched jets presumably give NaN/None averages; fit only the
            # finite points instead of passing non-finite data to curve_fit.
            finite = np.isfinite(x) & np.isfinite(y)
            if np.count_nonzero(finite) < 3:
                # Too few populated bins for a 3-parameter fit: keep placeholder params.
                return np.ones(3)
            popt, _ = opt.curve_fit(fitted_function, x[finite], y[finite])
            return popt

        optimizedParameters_E = np.ones((n_eta, 3))
        optimizedParameters_PT = np.ones((n_eta, 3))
        for e in range(n_eta):
            optimizedParameters_E[e] = fit_region(rjets_meanE_reg[e,:], gjets_meanE_reg[e,:])
            optimizedParameters_PT[e] = fit_region(rjets_meanPT_reg[e,:], gjets_meanPT_reg[e,:])

        #----> We access the fitted function using: fitted_function(x, *optimizedParameters[eta_region])
        # We can plot the fitted curve and original E of gen & reco jets
        r'''
        for eta in range(0,10):
            fig, ax = plt.subplots(1)
            plt.plot(rjets_meanPT_reg[eta,:], gjets_meanPT_reg[eta,:], ".", label="Data")
            plt.plot(rjets_meanPT_reg[eta,:], fitted_function(rjets_meanPT_reg[eta,:], *optimizedParameters_PT[eta]), label="fit")
            plt.legend()
            ax.set_xlabel(r'$<p_{T}^{reco}>$ [GeV]', fontsize=13)
            ax.set_ylabel(r'$<p_{T}^{gen}>$ [GeV]', fontsize=13)
            process = plt.text(1.0, 1.0, r"$H \rightarrow b \bar{b}$ (1.5TeV)", fontsize=12, horizontalalignment="right", verticalalignment="bottom", transform=ax.transAxes)
            MC = plt.text(0.0, 1.0, r"$\bf{Muon Collider}$ Preliminary", fontsize=12, horizontalalignment="left", verticalalignment="bottom", transform=ax.transAxes)
            Sidenote1 = plt.text(0.70, 0.60,fr"${etaregions[eta]}<|\eta|<{etaregions[eta+1]}$", fontsize=12, transform=ax.transAxes)
            sidenote1 = plt.text(0.20, 0.67, r'$y = a x^{b} + c$', fontsize=12, transform=ax.transAxes)
            sidenote2 = plt.text(0.12, 0.57, f'[a={optimizedParameters_PT[eta,0]:.1f}, b={optimizedParameters_PT[eta,1]:.1f}, c={optimizedParameters_PT[eta,2]:.1f}]', fontsize=12, transform=ax.transAxes)
            #Sidenote2 = plt.text(0.20, 0.60,fr"y={optimizedParameters[0,0]:.1f} $x^{optimizedParameters[0,1]:.1f}$ {optimizedParameters[0,2]:.1f}", fontsize=12, transform=ax.transAxes)
            plt.savefig(f'transferfunc_etareg{eta}_PTfit.png')
        '''

        # Step (4):
        # Scale factor per matched reco jet, chosen by the reco jet's |eta| region.
        # Region 0 seeds the array; the loops cover regions 1..n_eta-1, including the
        # last one (2.25 <= |eta| < 2.5). Jets outside 0 <= |eta| < 2.5 keep factor 0.
        eta0 = (reta>=0) & (reta<0.25)

        #Method1: scale factor is gen_jet E / reco_jet E for each matched pair.
        scale_factors_m1 = ak.where(eta0, (ak.mask(gjets.energy, eta0))/(ak.mask(rjets.energy, eta0)), 0)
        for j in range(1, n_eta):
            etareg = (reta>=etaregions[j]) & (reta<etaregions[j+1])
            scale_factors_m1 = ak.where(etareg, (ak.mask(gjets.energy, etareg))/(ak.mask(rjets.energy, etareg)), scale_factors_m1)

        #Method2: scale factor is fitted corr_reco_jet E / reco_jet E for each eta region.
        scale_factors_m2 = ak.where(eta0, (fitted_function(ak.mask(rjets.energy, eta0), *optimizedParameters_E[0]))/(ak.mask(rjets.energy, eta0)), 0)
        for j in range(1, n_eta):
            etareg = (reta>=etaregions[j]) & (reta<etaregions[j+1])
            scale_factors_m2 = ak.where(etareg, (fitted_function(ak.mask(rjets.energy, etareg), *optimizedParameters_E[j]))/(ak.mask(rjets.energy, etareg)), scale_factors_m2)

        #Method3: scale factor is fitted corr_reco_jet PT / reco_jet PT for each eta region.
        scale_factors_m3 = ak.where(eta0, (fitted_function(ak.mask(rjets.pt, eta0), *optimizedParameters_PT[0]))/(ak.mask(rjets.pt, eta0)), 0)
        for j in range(1, n_eta):
            etareg = (reta>=etaregions[j]) & (reta<etaregions[j+1])
            scale_factors_m3 = ak.where(etareg, (fitted_function(ak.mask(rjets.pt, etareg), *optimizedParameters_PT[j]))/(ak.mask(rjets.pt, etareg)), scale_factors_m3)

        method = 1 #1 or 2 or 3
        if method==1:
            sf = scale_factors_m1
        elif method==2:
            sf = scale_factors_m2
        elif method==3:
            sf = scale_factors_m3
        else:
            raise ValueError(f"unknown jet correction method: {method}")

        # Correct the 4-momentum of reco jets by scale factors
        corr_reco_Jets = ak.zip(
            {
                "x": sf*rjets.x,
                "y": sf*rjets.y,
                "z": sf*rjets.z,
                "t": sf*rjets.t,
                "cost": rjets.cost,
            },
            with_name="LorentzVector",
            behavior=vector.behavior,
        )

        # Calculate resolution:
        g_pt = gjets.pt
        g_eta = gjets.eta
        g_ene = gjets.energy
        cr_pt = corr_reco_Jets.pt
        cr_eta = corr_reco_Jets.eta
        cr_ene = corr_reco_Jets.energy

        pt_resolution = abs(g_pt - cr_pt)/g_pt
        energy_resolution = abs(g_ene - cr_ene)/g_ene
        eta_resolution = (g_eta - cr_eta)/(g_eta)

        print(pt_resolution)
        print(energy_resolution)
        print(eta_resolution)

        r'''
        #plot pt and eta resolution histograms
        fig, ax = plt.subplots(1)
        plt.hist(x=ak.flatten(eta_resolution), range=(-1,1), bins=50, histtype='step')
        plt.xlabel(r'$(\eta^{gen} - \eta^{reco})/\eta^{gen}$', fontsize=14)
        plt.ylabel('A.U', fontsize=14)
        process = plt.text(1.0, 1.0, r"$H \rightarrow b \bar{b}$ (1.5TeV)", fontsize=12, horizontalalignment="right", verticalalignment="bottom", transform=ax.transAxes)
        MC = plt.text(0.0, 1.0, r"$\bf{Muon Collider}$ Preliminary", fontsize=12, horizontalalignment="left", verticalalignment="bottom", transform=ax.transAxes)
        Title = plt.text(0.75, 0.60, f"Method{method} correction", fontsize=12, horizontalalignment="center", verticalalignment="bottom", transform=ax.transAxes)
        #Sidenote1 = plt.text(0.75, 0.67,r"$p_{T}>$ GeV", fontsize=12, transform=ax.transAxes)
        plt.savefig(f'etaResolution_method{method}.png')

        fig, ax = plt.subplots(1)
        plt.hist(x=ak.flatten(pt_resolution), range=(0,2), bins=50, histtype='step')
        plt.xlabel(r'$|p_{T}^{gen} - p_{T}^{reco}|/p_{T}^{gen}$', fontsize=14)
        plt.ylabel('A.U', fontsize=14)
        process = plt.text(1.0, 1.0, r"$H \rightarrow b \bar{b}$ (1.5TeV)", fontsize=12, horizontalalignment="right", verticalalignment="bottom", transform=ax.transAxes)
        MC = plt.text(0.0, 1.0, r"$\bf{Muon Collider}$ Preliminary", fontsize=12, horizontalalignment="left", verticalalignment="bottom", transform=ax.transAxes)
        Title = plt.text(0.75, 0.60, f"Method{method} correction", fontsize=12, horizontalalignment="center", verticalalignment="bottom", transform=ax.transAxes)
        #Sidenote1 = plt.text(0.75, 0.67,r"$p_{T}>$ GeV", fontsize=12, transform=ax.transAxes)
        plt.savefig(f'ptResolution_method{method}.png')
        '''

        # pT resolution as a function of gen pT: mean |gen - corrected| / gen per pT bin.
        pt_bins = np.linspace(10, 200, num=20)
        pt_Reso = np.ones(len(pt_bins) - 1)
        for j in range(len(pt_bins) - 1):
            ptreg = (g_pt>=pt_bins[j]) & (g_pt<pt_bins[j+1])
            gpt_bin = ak.mask(g_pt, ptreg)
            crpt_bin = ak.mask(cr_pt, ptreg)
            pt_Reso[j] = ak.mean(ak.flatten(abs(gpt_bin - crpt_bin)/gpt_bin))
        print(pt_Reso)
        r'''
        fig, ax = plt.subplots(1)
        plt.step(x=pt_bins[:19], y=pt_Reso, where='post')
        plt.xlabel(r'$p_{T}^{gen}$ [GeV]', fontsize=14)
        plt.ylabel(r'jet $p_{T}$ resolution', fontsize=14)
        #plt.ylim(0,0.55)
        process = plt.text(1.0, 1.0, r"$H \rightarrow b \bar{b}$ (1.5TeV)", fontsize=12, horizontalalignment="right", verticalalignment="bottom", transform=ax.transAxes)
        MC = plt.text(0.0, 1.0, r"$\bf{Muon Collider}$ Preliminary", fontsize=12, horizontalalignment="left", verticalalignment="bottom", transform=ax.transAxes)
        Title = plt.text(0.75, 0.60, f"Method{method} correction", fontsize=12, horizontalalignment="center", verticalalignment="bottom", transform=ax.transAxes)
        plt.savefig(f'ptResolution_vsPT_method{method}.png')
        '''

        return gjets, rjets, corr_reco_Jets

    def postprocess(self, accumulator):
        """Return the accumulated output unchanged (no post-processing step)."""
        return accumulator
    
p = MyProcessor()

# Matched gen/reco jets and the momentum-corrected reco jets.
# `CorrRecoJet` must exist as a module-level name before `process` is called,
# because `process` reads that global.
GenJet, RecoJet, CorrRecoJet = p.performance_Jet(eve_gen)

for _tag, _jets in (
    ('GenJet: ', GenJet),
    ('RecoJet: ', RecoJet),
    ('CorrRecoJet: ', CorrRecoJet),
):
    print(_tag, _jets.pt, _jets.eta)

# One histogram accumulator per sample, labelled through the `dataset` category.
output_gen = p.process(eve_gen, 'truth-level')
output_reco_sig = p.process(eve_reco_wobib, 'reco w/o BIB')
output_reco_sigbib = p.process(eve_reco_wbib, 'reco w/ BIB')