In [1]:
import os
import h5py
import shutil
import itertools

import numpy as np

from tqdm import tqdm

In [2]:
def all_pairs(lst):
    """Generate every way of splitting ``lst`` into unordered pairs.

    Each yielded value is a list of 2-tuples. A sequence with fewer than two
    elements yields a single empty list. For an odd-length sequence, every
    element is left out in turn and the pairings of the remaining elements
    are generated.

    ``lst`` must support ``len``, slicing and ``+`` concatenation (a list or
    a tuple both work).
    """
    size = len(lst)

    # Base case: nothing left to pair up.
    if size < 2:
        yield []
        return

    if size % 2 == 0:
        # Fix the first element and try every possible partner for it; the
        # remaining elements are paired recursively.
        head = lst[0]
        for k in range(1, size):
            first_pair = [(head, lst[k])]
            remaining = lst[1:k] + lst[k + 1:]
            for tail in all_pairs(remaining):
                yield first_pair + tail
    else:
        # Odd length: drop one element at a time and pair up the others.
        for skip in range(size):
            yield from all_pairs(lst[:skip] + lst[skip + 1:])

                
def Mjets(jets):
    """Return the invariant mass of the summed four-momenta of ``jets``.

    ``jets`` is a NumPy array of shape (n, 4), where n is the number of jets
    and each row holds the four attributes (pt, eta, phi, m) of one jet.
    """
    # Split the jet attributes into one array per quantity.
    pt, eta, phi, m = jets.T

    # Cartesian momentum components and energy of every jet.
    px = pt * np.cos(phi)
    py = pt * np.sin(phi)
    pz = pt * np.sinh(eta)
    e = np.sqrt(m*m + px*px + py*py + pz*pz)

    # Sum the four-vectors, then take the mass of the total.
    e_tot = e.sum()
    px_tot = px.sum()
    py_tot = py.sum()
    pz_tot = pz.sum()
    mass_squared = e_tot**2 - px_tot**2 - py_tot**2 - pz_tot**2

    return np.sqrt(mass_squared)


def PxPyPzE(jets):
    """Return the summed Cartesian four-momentum ``(px, py, pz, e)`` of ``jets``.

    ``jets`` is a NumPy array of shape (n, 4), where n is the number of jets
    and each row holds the four attributes (pt, eta, phi, m) of one jet.
    """
    pt, eta, phi, m = jets.T

    # Per-jet Cartesian components.
    px = pt * np.cos(phi)
    py = pt * np.sin(phi)
    pz = pt * np.sinh(eta)

    # Per-jet energy from the mass-shell relation.
    e = np.sqrt(m*m + px*px + py*py + pz*pz)

    total_px = px.sum()
    total_py = py.sum()
    total_pz = pz.sum()
    total_e = e.sum()
    return total_px, total_py, total_pz, total_e


def PtEtaPhiM(px, py, pz, e):
    """Convert a Cartesian four-momentum to ``(pt, eta, phi, m)``.

    Parameters
    ----------
    px, py, pz, e : float or array-like
        Cartesian momentum components and energy.

    Returns
    -------
    tuple
        ``(pt, eta, phi, m)`` where ``pt`` is the transverse momentum,
        ``eta = 0.5 * ln((P + pz) / (P - pz))`` with ``P`` the momentum
        magnitude, ``phi`` is the azimuthal angle in (-pi, pi], and
        ``m = sqrt(e^2 - |p|^2)``.

    Notes
    -----
    ``eta`` is not finite when ``px == py == 0`` because ``P - pz`` or
    ``P + pz`` is then zero.
    """
    P = np.sqrt(px**2 + py**2 + pz**2)
    pt = np.sqrt(px**2 + py**2)
    eta = 1/2 * np.log((P + pz)/(P - pz))
    # arctan2 keeps the quadrant information that arctan(py/px) discards
    # (phi would be off by pi for px < 0) and is well defined at px == 0.
    phi = np.arctan2(py, px)
    m = np.sqrt(e**2 - px**2 - py**2 - pz**2)

    return pt, eta, phi, m

In [3]:
def chi2_triHiggs(m1, m2, m3, mh=125.0):
    """Sum of squared deviations of three candidate masses from ``mh``.

    Parameters
    ----------
    m1, m2, m3 : float
        Masses of the three candidates.
    mh : float, optional
        Reference mass every candidate is compared against. Defaults to
        125.0, the value this function previously hard-coded.

    Returns
    -------
    float
        ``(m1 - mh)**2 + (m2 - mh)**2 + (m3 - mh)**2``; smaller is better.
    """
    return (m1 - mh)**2 + (m2 - mh)**2 + (m3 - mh)**2


def abs_triHiggs(m1, m2, m3, targets=(120, 115, 110)):
    """Sum of absolute deviations of three candidate masses from per-candidate targets.

    Parameters
    ----------
    m1, m2, m3 : float
        Masses of the three candidates. ``perform_jet_pairing`` in this file
        passes them ordered by decreasing candidate pT.
    targets : sequence of three floats, optional
        Reference masses compared against ``m1``, ``m2`` and ``m3``
        respectively. Defaults to ``(120, 115, 110)``, the values this
        function previously hard-coded.

    Returns
    -------
    float
        ``|m1 - t1| + |m2 - t2| + |m3 - t3|``; smaller is better.
    """
    t1, t2, t3 = targets
    return abs(m1 - t1) + abs(m2 - t2) + abs(m3 - t3)

In [4]:
def _best_pairing(pt, eta, phi, mass, jets_index, pairing_method):
    """Find the best-scoring split of six jets into three dijet candidates.

    Every 6-jet subset of ``jets_index`` and every way of pairing those six
    jets is tried. For each pairing the three dijet candidates are ordered by
    decreasing pT and their masses are passed, in that order, to
    ``pairing_method``; the pairing with the smallest score wins.

    Returns a flat list of six jet indices (b1, b2 of the leading candidate,
    then of the second, then of the third), or ``None`` when ``jets_index``
    holds fewer than six jets.
    """
    best_score = None
    best_assignment = None

    for combination in itertools.combinations(jets_index, 6):
        for pairing in all_pairs(combination):
            candidates = []
            for jet_pair in pairing:
                jets = np.array([[pt[i], eta[i], phi[i], mass[i]] for i in jet_pair])
                cand_pt, _, _, cand_mass = PtEtaPhiM(*PxPyPzE(jets))
                candidates.append((cand_pt, cand_mass, jet_pair))

            # Leading-pT candidate first.
            candidates.sort()
            candidates.reverse()

            score = pairing_method(*(cand[1] for cand in candidates))
            # A NaN score (e.g. from a negative mass-squared caused by
            # rounding) never compares smaller than anything, so rank it as
            # the worst possible score instead of letting it block updates.
            if np.isnan(score):
                score = np.inf

            # `None` marks "nothing chosen yet"; unlike a negative sentinel
            # this also works for metrics that can return negative scores.
            if best_assignment is None or score < best_score:
                best_score = score
                best_assignment = [jet for _, _, jet_pair in candidates for jet in jet_pair]

    return best_assignment


def perform_jet_pairing(file_path, output_path, use_btag=False, pairing_method=chi2_triHiggs):
    """Assign jets to three dijet candidates per event and store the result.

    ``file_path`` is copied to ``output_path``; the copy's
    ``TARGETS/h{1,2,3}/b{1,2}`` datasets are then overwritten, event by
    event, with the jet indices of the best pairing.

    Parameters
    ----------
    file_path : str
        Input HDF5 file. The datasets ``INPUTS/Source/{MASK,pt,eta,phi,mass,btag}``
        are read from it, fully into memory, once.
    output_path : str
        Destination HDF5 file (created as a copy of ``file_path``).
    use_btag : bool, optional
        If True, only the first six jets with a truthy ``btag`` value are
        considered. Otherwise the first ``MASK.sum()`` jets of the event are
        considered and every 6-jet subset is tried.
    pairing_method : callable, optional
        ``pairing_method(m1, m2, m3)`` returning a score to minimise; the
        masses are those of the candidates ordered by decreasing pT.

    Notes
    -----
    Events with fewer than six candidate jets get -1 written to all six
    targets instead of raising ``IndexError``.
    NOTE(review): -1 is assumed to be the "no jet" marker and the target
    datasets are assumed to have a signed integer dtype - confirm against
    the file format in use.
    """
    target_names = (
        'TARGETS/h1/b1', 'TARGETS/h1/b2',
        'TARGETS/h2/b1', 'TARGETS/h2/b2',
        'TARGETS/h3/b1', 'TARGETS/h3/b2',
    )

    shutil.copy(file_path, output_path)
    with h5py.File(file_path, 'r') as f, h5py.File(output_path, 'a') as f_out:

        # Read each input dataset once rather than once per event.
        mask_all = f['INPUTS/Source/MASK'][...]
        pt_all = f['INPUTS/Source/pt'][...]
        eta_all = f['INPUTS/Source/eta'][...]
        phi_all = f['INPUTS/Source/phi'][...]
        mass_all = f['INPUTS/Source/mass'][...]
        btag_all = f['INPUTS/Source/btag'][...]

        targets = [f_out[name] for name in target_names]
        nevent = pt_all.shape[0]

        for event in tqdm(range(nevent)):

            if use_btag:
                jets_index = np.where(btag_all[event])[0][0:6]
            else:
                jets_index = range(mask_all[event].sum())

            assignment = _best_pairing(
                pt_all[event], eta_all[event], phi_all[event], mass_all[event],
                jets_index, pairing_method,
            )
            if assignment is None:
                # Fewer than six candidate jets: no pairing exists.
                assignment = [-1] * len(target_names)

            for dataset, jet in zip(targets, assignment):
                dataset[event] = jet

In [5]:
# Pair the jets of every event in `file_path` using the chi-square metric,
# considering only b-tagged jets. The result is written to `output_path`,
# which perform_jet_pairing creates as a copy of the input file.
file_path = '../SPANet2/data/triHiggs/gghhh_6b_PT40_test.h5'
output_path = '../SPANet2/data/triHiggs/gghhh_6b_PT40_test-chi2_pairing.h5'

perform_jet_pairing(file_path, output_path, use_btag=True, pairing_method=chi2_triHiggs)

100%|██████████| 40000/40000 [03:10<00:00, 209.98it/s]


In [5]:
# Same input file as the chi-square run, but scored with the absolute-
# difference metric (abs_triHiggs); only b-tagged jets are considered and the
# result goes to a separate copy of the input file.
file_path = '../SPANet2/data/triHiggs/gghhh_6b_PT40_test.h5'
output_path = '../SPANet2/data/triHiggs/gghhh_6b_PT40_test-abs_pairing.h5'

perform_jet_pairing(file_path, output_path, use_btag=True, pairing_method=abs_triHiggs)

100%|██████████| 40000/40000 [03:18<00:00, 201.83it/s]


In [5]:
# Pair the jets of every event in the pp6b sample file using the chi-square
# metric, considering only b-tagged jets. The result is written to
# `output_path`, a copy of the input file.
file_path = './Sample/SPANet/pp6b_6b.h5'
output_path = './Sample/SPANet/pp6b_6b-chi2_pairing.h5'

perform_jet_pairing(file_path, output_path, use_btag=True, pairing_method=chi2_triHiggs)

100%|██████████| 28755/28755 [02:36<00:00, 184.14it/s]


In [None]:
# Same pp6b input file, scored with the absolute-difference metric
# (abs_triHiggs); only b-tagged jets are considered and the result goes to a
# separate copy of the input file.
file_path = './Sample/SPANet/pp6b_6b.h5'
output_path = './Sample/SPANet/pp6b_6b-abs_pairing.h5'

perform_jet_pairing(file_path, output_path, use_btag=True, pairing_method=abs_triHiggs)

100%|██████████| 28755/28755 [02:35<00:00, 185.25it/s]
