In [3]:
import onnxruntime   
import h5py
import numpy as np
import numba
from typing import List
from numba import njit


In [4]:
## Algorithm to extract the best jet assignment from a set of predictions
# https://github.com/Alexanders101/SPANet/blob/master/spanet/network/prediction_selection.py


# Plain NumPy alias, used only for the type hint on `extract_predictions`.
TArray = np.ndarray

# Numba scalar types used to build the explicit jit signatures below.
TFloat32 = numba.types.float32
TInt64 = numba.types.int64

# Typed-list signatures:
#   TPrediction  - list of flat (1D, C-contiguous) float32 arrays, one per target (single event)
#   TPredictions - list of 2D C-contiguous float32 arrays, one per target (whole batch)
TPrediction = numba.typed.typedlist.ListType(TFloat32[::1])
TPredictions = numba.typed.typedlist.ListType(TFloat32[:, ::1])

# C-contiguous int64 result arrays: 2D for a single event, 3D for a batch.
TResult = TInt64[:, ::1]
TResults = TInt64[:, :, ::1]

# Set to True to replace `njit` with a no-op decorator so the functions below
# run as ordinary Python (useful for stepping through them in a debugger).
NUMBA_DEBUG = False


if NUMBA_DEBUG:
    def njit(*args, **kwargs):
        """Stand-in for ``numba.njit`` that ignores its arguments and returns the function unchanged."""
        def wrapper(function):
            return function
        return wrapper


@njit("void(float32[::1], int64, int64, float32)")
def mask_1(data, size, index, value):
    """Overwrite the entry at ``index`` of a flat 1D score array with ``value``.

    ``size`` is not used here; it is accepted so that the signature is the
    same as ``mask_2`` and ``mask_3``. The array is modified in place.
    """
    data[index] = value


@njit("void(float32[::1], int64, int64, float32)")
def mask_2(flat_data, size, index, value):
    """Overwrite row ``index`` and column ``index`` of a flattened square matrix.

    ``flat_data`` is viewed as a ``(size, size)`` matrix; the writes go
    through that view, so ``flat_data`` itself is modified in place.
    """
    grid = flat_data.reshape((size, size))
    # Both writes store the same value, so their order does not matter.
    grid[:, index] = value
    grid[index, :] = value


@njit("void(float32[::1], int64, int64, float32)")
def mask_3(flat_data, size, index, value):
    """Overwrite every entry of a flattened cube that uses ``index`` on any axis.

    ``flat_data`` is viewed as a ``(size, size, size)`` array; the writes go
    through that view, so ``flat_data`` itself is modified in place.
    """
    cube = flat_data.reshape((size, size, size))
    # One slab per axis; all writes store the same value, so order is irrelevant.
    cube[:, :, index] = value
    cube[:, index, :] = value
    cube[index, :, :] = value


# @njit("void(float32[::1], int64, int64, float32)")
# def mask_4(flat_data, size, index, value):
#     data = flat_data.reshape((size, size, size, size))
#     data[index, :, :, :] = value
#     data[:, index, :, :] = value
#     data[:, :, index, :] = value
#     data[:, :, :, index] = value


# @njit("void(float32[::1], int64, int64, float32)")
# def mask_5(flat_data, size, index, value):
#     data = flat_data.reshape((size, size, size, size, size))
#     data[index, :, :, :, :] = value
#     data[:, index, :, :, :] = value
#     data[:, :, index, :, :] = value
#     data[:, :, :, index, :] = value
#     data[:, :, :, :, index] = value
#
#
# @njit("void(float32[::1], int64, int64, float32)")
# def mask_6(flat_data, size, index, value):
#     data = flat_data.reshape((size, size, size, size, size, size))
#     data[index, :, :, :, :, :] = value
#     data[:, index, :, :, :, :] = value
#     data[:, :, index, :, :, :] = value
#     data[:, :, :, index, :, :] = value
#     data[:, :, :, :, index, :] = value
#     data[:, :, :, :, :, index] = value


# @njit("void(float32[::1], int64, int64, float32)")
# def mask_7(flat_data, size, index, value):
#     data = flat_data.reshape((size, size, size, size, size, size, size))
#     data[index, :, :, :, :, :, :] = value
#     data[:, index, :, :, :, :, :] = value
#     data[:, :, index, :, :, :, :] = value
#     data[:, :, :, index, :, :, :] = value
#     data[:, :, :, :, index, :, :] = value
#     data[:, :, :, :, :, index, :] = value
#     data[:, :, :, :, :, :, index] = value
#
#
# @njit("void(float32[::1], int64, int64, float32)")
# def mask_8(flat_data, size, index, value):
#     data = flat_data.reshape((size, size, size, size, size, size, size, size))
#     data[index, :, :, :, :, :, :, :] = value
#     data[:, index, :, :, :, :, :, :] = value
#     data[:, :, index, :, :, :, :, :] = value
#     data[:, :, :, index, :, :, :, :] = value
#     data[:, :, :, :, index, :, :, :] = value
#     data[:, :, :, :, :, index, :, :] = value
#     data[:, :, :, :, :, :, index, :] = value
#     data[:, :, :, :, :, :, :, index] = value


@njit("void(float32[::1], int64, int64, int64, float32)")
def mask_jet(data, num_partons, max_jets, index, value):
    """Overwrite every entry of a flattened score array that uses jet ``index``.

    ``data`` is the flattened form of an array with ``num_partons`` axes of
    length ``max_jets`` each. Every entry whose multi-index contains ``index``
    on at least one axis is set to ``value``. The array is modified in place.

    One, two and three partons are dispatched to the slice-based helpers
    ``mask_1`` / ``mask_2`` / ``mask_3``. Any larger number of partons is
    handled by a generic loop over the flat array (previously such targets
    were silently left unmasked).
    """
    if num_partons == 1:
        mask_1(data, max_jets, index, value)
    elif num_partons == 2:
        mask_2(data, max_jets, index, value)
    elif num_partons == 3:
        mask_3(data, max_jets, index, value)
    else:
        # Generic fallback: decode each flat position into its base-`max_jets`
        # digits (one digit per axis) and mask it if any digit equals `index`.
        for flat in range(data.shape[0]):
            remainder = flat
            for _ in range(num_partons):
                if remainder % max_jets == index:
                    data[flat] = value
                    break
                remainder //= max_jets


@njit("int64[::1](int64, int64)")
def compute_strides(num_partons, max_jets):
    """Return the row-major strides of an array with ``num_partons`` axes of length ``max_jets``.

    The last axis has stride 1 and each earlier axis has the stride of the
    following axis multiplied by ``max_jets``.
    """
    weights = np.zeros(num_partons, dtype=np.int64)
    weights[-1] = 1

    # Walk from the second-to-last axis back to the first.
    axis = num_partons - 2
    while axis >= 0:
        weights[axis] = weights[axis + 1] * max_jets
        axis -= 1

    return weights


@njit(TInt64[::1](TInt64, TInt64[::1]))
def unravel_index(index, strides):
    """Convert a flat ``index`` into one coordinate per axis using ``strides``.

    Each coordinate is the integer quotient of what is left of the flat index
    by that axis' stride; the remainder is carried to the next axis.
    """
    coordinates = np.zeros(strides.shape[0], dtype=np.int64)

    leftover = index
    for axis in range(strides.shape[0]):
        stride = strides[axis]
        coordinates[axis] = leftover // stride
        leftover = leftover % stride
    return coordinates


@njit(TInt64(TInt64[::1], TInt64[::1]))
def ravel_index(index, strides):
    """Convert per-axis coordinates back into a flat index (inverse of ``unravel_index``)."""
    weighted = index * strides
    return weighted.sum()


@njit(numba.types.Tuple((TInt64, TInt64, TFloat32))(TPrediction))
def maximal_prediction(predictions):
    """Find the single highest score across all targets.

    Returns a tuple ``(flat_index, target, value)``: the flat position of the
    winning entry inside its target array, the position of that target in
    ``predictions``, and the winning score. If no entry is greater than
    ``-inf`` both indices are ``-1`` and the value is ``-inf``.
    """
    top_jet = -1
    top_target = -1
    top_value = -np.float32(np.inf)

    for target in range(len(predictions)):
        scores = predictions[target]
        candidate_jet = np.argmax(scores)
        candidate_value = scores[candidate_jet]

        # Strict comparison: on a tie the earlier target is kept.
        if candidate_value > top_value:
            top_jet = candidate_jet
            top_target = target
            top_value = candidate_value

    return top_jet, top_target, top_value


@njit(TResult(TPrediction, TInt64[::1], TInt64))
def extract_prediction(predictions, num_partons, max_jets):
    """Greedily pick a jet assignment for every target of a single event.

    On each pass the globally highest remaining score is selected, its jets
    are recorded for the corresponding target, that target is blanked out and
    the chosen jets are masked in every target so they cannot be reused.

    Parameters
    ----------
    predictions : typed list of flat float32 arrays, one per target.
        MODIFIED IN PLACE: selected targets and used jets are overwritten
        with ``-inf``.
    num_partons : int64 array with the number of partons of each target.
    max_jets : length of every axis of the (unflattened) score arrays.

    Returns
    -------
    int64 array of shape ``(num_targets, max(num_partons))``. A row holds the
    chosen jet indices, padded with ``-1`` beyond that target's parton count;
    a row of ``-2`` means no assignment was made for that target.
    """
    float_negative_inf = -np.float32(np.inf)
    max_partons = num_partons.max()
    num_targets = len(predictions)

    # Pre-compute the row-major strides of every target so flat indices can be
    # unravelled later. (No copy is made here; `_extract_predictions` copies
    # the arrays before calling this function.)
    strides = []
    for i in range(num_targets):
        strides.append(compute_strides(num_partons[i], max_jets))

    # Fill up the prediction matrix
    # -2 : Not yet assigned
    # -1 : Masked value
    # else : The actual index value
    results = np.zeros((num_targets, max_partons), np.int64) - 2

    # At most one target is resolved per pass.
    for _ in range(num_targets):
        best_jet, best_prediction, best_value = maximal_prediction(predictions)

        # Nothing finite left to pick: remaining targets keep their -2 rows.
        if not np.isfinite(best_value):
            return results

        best_jets = unravel_index(best_jet, strides[best_prediction])

        # Pad the whole row with -1, then write the real indices over the
        # first `num_partons[best_prediction]` slots.
        results[best_prediction, :] = -1
        for i in range(num_partons[best_prediction]):
            results[best_prediction, i] = best_jets[i]

        # Remove the resolved target from further consideration ...
        predictions[best_prediction][:] = float_negative_inf
        # ... and forbid its jets in every target.
        for i in range(num_targets):
            for jet in best_jets:
                mask_jet(predictions[i], num_partons[i], max_jets, jet, float_negative_inf)

    return results


@njit(TResults(TPredictions, TInt64[::1], TInt64, TInt64), parallel=True)
def _extract_predictions(predictions, num_partons, max_jets, batch_size):
    """Run ``extract_prediction`` on every event of a batch.

    Parameters
    ----------
    predictions : typed list of 2D float32 arrays, one per target, each with
        one flattened score array per event along the first axis.
    num_partons : int64 array with the number of partons of each target.
    max_jets : length of every axis of the (unflattened) score arrays.
    batch_size : number of events to process.

    Returns
    -------
    C-contiguous int64 array of shape ``(num_targets, batch_size, max(num_partons))``.
    """
    output = np.zeros((batch_size, len(predictions), num_partons.max()), np.int64)
    # Work on copies: `extract_prediction` overwrites the scores it is given.
    predictions = [p.copy() for p in predictions]

    # Events are processed with `numba.prange`; each iteration only touches its own row.
    for batch in numba.prange(batch_size):
        current_prediction = numba.typed.List([prediction[batch] for prediction in predictions])
        output[batch, :, :] = extract_prediction(current_prediction, num_partons, max_jets)

    # (batch, target, parton) -> (target, batch, parton)
    return np.ascontiguousarray(output.transpose((1, 0, 2)))


def extract_predictions(predictions: List[TArray]):
    """Extract the best non-overlapping jet assignment for every event.

    Parameters
    ----------
    predictions : list of arrays, one per target. Each array has the event
        axis first, followed by one axis per parton of that target.

    Returns
    -------
    list of int64 arrays, one per target, each of shape
    ``(batch_size, num_partons_of_that_target)`` holding the chosen jet indices.

    Raises
    ------
    ValueError
        If the arrays do not all have the same number of events.
    """
    # The compiled kernel indexes every target with the same event index and
    # does no bounds checking, so mismatched batch sizes must be rejected here.
    batch_sizes = {p.shape[0] for p in predictions}
    if len(batch_sizes) > 1:
        raise ValueError(
            "all predictions must have the same batch size, got %s" % sorted(batch_sizes)
        )

    # Flatten the parton axes. The explicit product (instead of -1) keeps the
    # reshape valid for an empty batch; the kernel signature requires
    # C-contiguous float32, so convert only when necessary.
    flat_predictions = numba.typed.List([
        np.ascontiguousarray(
            p.reshape((p.shape[0], int(np.prod(p.shape[1:])))),
            dtype=np.float32,
        )
        for p in predictions
    ])
    # Explicit int64: the platform default integer is not int64 everywhere,
    # but the kernel signature is.
    num_partons = np.array([len(p.shape) - 1 for p in predictions], dtype=np.int64)
    max_jets = max(max(p.shape[1:]) for p in predictions)
    batch_size = max(p.shape[0] for p in predictions)

    results = _extract_predictions(flat_predictions, num_partons, max_jets, batch_size)
    # Drop the padding columns of targets with fewer partons than the maximum.
    return [result[:, :partons] for result, partons in zip(results, num_partons)]

In [5]:
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
file_name="/t3home/mmalucch//spanet_5jets_ptreg_ATLAS.onnx"

# Load the ONNX model with every execution provider available on this machine.
session = onnxruntime.InferenceSession(
    file_name,
    providers=onnxruntime.get_available_providers()
)

# Query the model's input/output descriptions once and reuse them.
model_inputs = session.get_inputs()
model_outputs = session.get_outputs()

# print the input/output name and shape
input_name = [node.name for node in model_inputs]
output_name = [node.name for node in model_outputs]
print("Inputs name:", input_name)
print("Outputs name:", output_name)

input_shape = [node.shape for node in model_inputs]
output_shape = [node.shape for node in model_outputs]
print("Inputs shape:", input_shape)
print("Outputs shape:", output_shape)

Inputs name: ['Jet_data', 'Jet_mask']
Outputs name: ['h1_assignment_probability', 'h2_assignment_probability', 'h1_detection_probability', 'h2_detection_probability']
Inputs shape: [['batch_size', 'num_Jet', 4], ['batch_size', 'num_Jet']]
Outputs shape: [['Exph1_assignment_probability_dim_0', 'Exph1_assignment_probability_dim_1', 'Exph1_assignment_probability_dim_2'], ['Exph1_assignment_probability_dim_0', 'Exph1_assignment_probability_dim_1', 'Exph1_assignment_probability_dim_2'], ['Exph1_assignment_probability_dim_0'], ['Exph1_assignment_probability_dim_0']]


In [6]:
# Test sample: jet information together with the true assignment.
# The file handle is kept open because later cells read the targets from it.
filename_test="/work/mmalucch/out_hh4b/hh4b_9999_sin_cos_phi/output_JetGood_test.h5"
df_test = h5py.File(filename_test,'r')

# SPANet's own predictions for the same sample, used as the reference
# against which the extraction algorithm is checked.
filename_pred="/t3home/mmalucch/out_spanet_prediction_5jets_ptreg_ATLAS.h5"
df_pred = h5py.File(filename_pred,'r')

# Get the input variables
jet_group = df_test["INPUTS"]["Jet"]
# the pt is fed to the network as log(x+1)
ptPnetRegNeutrino = np.log(np.array(jet_group["ptPnetRegNeutrino"][()]) + 1)
eta = np.array(jet_group["eta"][()])
phi = np.array(jet_group["phi"][()])
btag = np.array(jet_group["btag"][()])
# stack the input variables along a new last axis
input_complete = np.stack([ptPnetRegNeutrino, eta, phi, btag], axis=-1)
print(input_complete.shape)

# Mask telling which jets are valid (since we consider 5 jets)
mask = np.array(jet_group["MASK"][()])


# Input dictionary for the ONNX session, keyed by the model's input names
x = {
    input_name[0]: input_complete,
    input_name[1]: mask,
}


(129710, 5, 4)


In [7]:
# Run the model on the full test sample; `x` is the input dictionary
# built in the data-loading cell.
outputs = session.run(output_names=output_name, input_feed=x)

In [10]:
# Quick look at what the model returned: number of outputs, their shapes,
# and the two assignment matrices of the first event.
print(len(outputs))
for output_index in range(4):
    print(outputs[output_index].shape)
for output_index in range(2):
    print(outputs[output_index][0])

# get the true targets
targets_test = df_test["TARGETS"]
idx_b1_test = targets_test["h1"]["b1"][()]
idx_b2_test = targets_test["h1"]["b2"][()]
idx_b3_test = targets_test["h2"]["b3"][()]
idx_b4_test = targets_test["h2"]["b4"][()]
print(idx_b1_test.shape)

# get the predicted targets
targets_pred = df_pred["TARGETS"]
idx_b1_pred = targets_pred["h1"]["b1"][()]
idx_b2_pred = targets_pred["h1"]["b2"][()]
idx_b3_pred = targets_pred["h2"]["b3"][()]
idx_b4_pred = targets_pred["h2"]["b4"][()]

# get the predicted probabilities from the predictions file
h1_prob = targets_pred["h1"]["assignment_probability"][()]
h2_prob = targets_pred["h2"]["assignment_probability"][()]

# Side-by-side dump of the first two events
for i in range(2):
    h1_matrix = outputs[0][i]
    h2_matrix = outputs[1][i]
    print(h1_matrix)
    print(h2_matrix)
    print("idx test", idx_b1_test[i], idx_b2_test[i], idx_b3_test[i], idx_b4_test[i])
    print("idx pred", idx_b1_pred[i], idx_b2_pred[i], idx_b3_pred[i], idx_b4_pred[i])
    print("h prob", h1_prob[i], h2_prob[i])

    print("\n")


4
(129710, 5, 5)
(129710, 5, 5)
(129710,)
(129710,)
[[0.0000000e+00 2.2281808e-04 1.1495069e-07 1.2401931e-04 9.9582564e-09]
 [2.2281808e-04 0.0000000e+00 8.3803292e-07 4.9964243e-01 1.8324774e-06]
 [1.1495069e-07 8.3803292e-07 0.0000000e+00 3.5860548e-06 1.4011460e-10]
 [1.2401931e-04 4.9964243e-01 3.5860548e-06 0.0000000e+00 4.3281375e-06]
 [9.9582564e-09 1.8324774e-06 1.4011460e-10 4.3281375e-06 0.0000000e+00]]
[[0.0000000e+00 6.6398010e-05 4.9939770e-01 2.8401175e-07 6.9324473e-05]
 [6.6398010e-05 0.0000000e+00 9.7741169e-05 2.6958940e-11 5.6810501e-07]
 [4.9939770e-01 9.7741169e-05 0.0000000e+00 5.4201127e-08 3.6791520e-04]
 [2.8401175e-07 2.6958940e-11 5.4201127e-08 0.0000000e+00 5.7382394e-11]
 [6.9324473e-05 5.6810501e-07 3.6791520e-04 5.7382394e-11 0.0000000e+00]]
(129710,)
[[0.0000000e+00 2.2281808e-04 1.1495069e-07 1.2401931e-04 9.9582564e-09]
 [2.2281808e-04 0.0000000e+00 8.3803292e-07 4.9964243e-01 1.8324774e-06]
 [1.1495069e-07 8.3803292e-07 0.0000000e+00 3.5860548e-06 1.

In [18]:
# extract the best jet assignment from
# the predicted probabilities
assignment_probability=[outputs[0],outputs[1]]
# NOTE: this rebinds `x`, which the data-loading cell defines as the model
# input dictionary; re-run that cell before re-running the inference cell.
x=extract_predictions(assignment_probability)
print("x",x)
# concatenate the per-target assignments into one (events, 4) array
y=np.concatenate([x[0],x[1]],axis=-1)
print("y",y)

# stack the assignments from the prediction file
pred_conc=np.stack((idx_b1_pred,idx_b2_pred,idx_b3_pred,idx_b4_pred),axis=-1)
print("pred_conc",pred_conc)

# Quantify the agreement instead of relying on the truncated printouts above.
if y.shape == pred_conc.shape:
    events_matching = np.all(y == pred_conc, axis=-1)
    print("fraction of events with identical assignment:", events_matching.mean())
else:
    print("cannot compare assignments, shapes differ:", y.shape, pred_conc.shape)

x [array([[1, 3],
       [0, 3],
       [1, 2],
       ...,
       [0, 1],
       [0, 2],
       [2, 3]]), array([[0, 2],
       [1, 2],
       [0, 3],
       ...,
       [2, 3],
       [1, 3],
       [0, 1]])]
y [[1 3 0 2]
 [0 3 1 2]
 [1 2 0 3]
 ...
 [0 1 2 3]
 [0 2 1 3]
 [2 3 0 1]]
pred_conc [[1 3 0 2]
 [0 3 1 2]
 [1 2 0 3]
 ...
 [0 1 2 3]
 [0 2 1 3]
 [2 3 0 1]]
