In [2]:
import onnx
import onnxruntime as ort
import numpy as np
from pathlib import Path

In [3]:
print(ort.__version__)

1.19.0


In [4]:
def inspect_and_apply(model_path: str, n_track_features: int, none_input: bool = False):
    model = onnx.load(model_path)
    onnx.checker.check_model(model)
    onnx.helper.printable_graph(model.graph)

    sess_opts = ort.SessionOptions()
    sess_opts.enable_mem_reuse = False
    ort_sess = ort.InferenceSession(model_path, sess_opts=sess_opts)
    input_names = [x.name for x in ort_sess.get_inputs()]
    input_shapes = [x.shape for x in ort_sess.get_inputs()]
    output_names = [x.name for x in ort_sess.get_outputs()]
    output_shapes = [x.shape for x in ort_sess.get_outputs()]
    print("Input names: ", input_names, "with shapes: ", input_shapes)
    print("Output names: ", output_names, "with shapes: ", output_shapes)


    jet_features = np.array([[85507.8, -3.05748]], dtype=np.float32)
    track_features = np.array([]*n_track_features, dtype=np.float32) \
        if none_input \
        else np.array([0.]*n_track_features, dtype=np.float32)
    track_features = track_features.reshape(-1, n_track_features)
    print("jet feature shape:", jet_features.shape)
    print("track feature shape:", track_features.shape)

    results = ort_sess.run(output_names, {input_names[0]: jet_features, input_names[1]: track_features})
    return results


In [8]:
def apply_GN2Xv02(model_path: str, no_track: bool = False):
    model = onnx.load(model_path)
    onnx.checker.check_model(model)
    onnx.helper.printable_graph(model.graph)

    sess_opts = ort.SessionOptions()
    sess_opts.enable_mem_reuse = False
    ort_sess = ort.InferenceSession(model_path, sess_opts=sess_opts)
    input_names = [x.name for x in ort_sess.get_inputs()]
    input_shapes = [x.shape for x in ort_sess.get_inputs()]
    output_names = [x.name for x in ort_sess.get_outputs()]
    output_shapes = [x.shape for x in ort_sess.get_outputs()]
    
    print("Input names: ", input_names, "with shapes: ", input_shapes)
    print("Output names: ", output_names, "with shapes: ", output_shapes)

    n_track_features = 20
    n_flow_features = 4
    jet_features = np.array([[110447.257812, -0.707432]], dtype=np.float32)

    track_features = np.array([ 2.182648, -2.773842, -0.333382, 0.275881, 0.001136, 25.440155, 19.198353, 0.002501, 0.003137, 0.000012, 5.000000, 7.000000, 2.000000, 2.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -1.395097, 3.133203, -0.025952, -0.021848, 0.001272, 8.388472, 14.784507, 0.005152, 0.004862, 0.000063, 3.000000, 4.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000, -0.506528, -0.087251, 0.079731, -0.386179, 0.001050, 3.007987, -0.438205, 0.004584, 0.002977, 0.000023, 4.000000, 8.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.163305, 0.007536, 0.291742, -0.125659, 0.000915, 1.351799, 0.053117, 0.003588, 0.002836, 0.000015, 4.000000, 8.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.079823, 0.116521, -0.401949, -0.237931, -0.001205, 0.572685, 0.676673, 0.004132, 0.003108, 0.000021, 5.000000, 10.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.106855, -0.194486, -0.365867, 0.260800, 0.001341, 0.541493, 0.698768, 0.004380, 0.004495, 0.000015, 3.000000, 6.000000, 0.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.029123, 0.155005, -0.122931, -0.619130, 0.000474, 0.337255, 1.430664, 0.002586, 0.001446, 0.000014, 4.000000, 5.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.007765, -0.075620, -0.143919, 0.344221, 0.001568, 0.066863, 0.361370, 0.003319, 0.004787, 0.000018, 3.000000, 7.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.019261, -4.957779, -0.266907, -0.814137, 0.000603, -0.127223, -25.857431, 0.004545, 0.002051, 0.000015, 4.000000, 8.000000, 2.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.021749, -0.059028, -0.370248, 0.084571, -0.000994, -0.255405, 0.393371, 0.002464, 0.002833, 0.000013, 4.000000, 8.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.057833, -0.056124, 0.414234, -0.721329, -0.000693, -0.330855, -0.307921, 0.005263, 0.002382, 0.000015, 3.000000, 10.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.040038, -0.086086, 0.247437, -0.251351, 0.000477, -0.535966, -0.951690, 0.001935, 0.001457, 0.000009, 4.000000, 8.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.047530, -0.083239, 0.417712, -0.388799, 0.000410, -0.705951, -0.967510, 0.001948, 0.001267, 0.000007, 3.000000, 8.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.132460, 0.041486, -0.397067, -0.857597, 0.000670, -0.781745, 0.202548, 0.005007, 0.002233, 0.000019, 4.000000, 8.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.132073, 0.239654, -0.501569, 0.300782, -0.001325, -1.325551, -1.615456, 0.002904, 0.003399, 0.000021, 4.000000, 9.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, -0.182409, 0.054681, 0.010012, 1.066626, -0.001872, -1.379352, -0.275418, 0.003934, 0.004710, 0.000033, 3.000000, 9.000000, 1.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,], dtype=np.float32)
    track_features = track_features.reshape(-1, n_track_features)

    flow_features = np.array([66414.562500, 81433.679688, 0.046980, -0.000448, 8230.662109, 9636.583008, 0.130956, -0.073864, 4774.290527, 4801.318359, 0.601076, -0.100902, 4275.401855, 4995.911133, 0.134731, 0.136726, 2947.914062, 5193.184570, -0.459449, 0.424706, 1464.797485, 2440.677734, -0.388799, 0.417712, 1402.786377, 2103.119873, -0.251351, 0.247437, 1349.344727, 2609.448730, -0.570472, -0.134479, 1088.754395, 1408.979614, -0.041884, 0.014473, 1045.537476, 2113.177246, -0.619130, -0.122931, 1036.354980, 2337.252686, -0.745750, 0.604241, 947.011108, 947.611267, 0.671832, -0.437875, 918.046814, 928.262268, 0.558389, -0.763152, 838.013550, 1015.529846, 0.084571, -0.370248, 799.070312, 1101.653931, -0.125659, 0.291742, 696.601196, 767.787903, 0.300782, -0.501569, 676.801575, 758.388794, 0.260800, -0.365867, 653.504395, 1448.737793, -0.721329, 0.414234, 597.778320, 652.739746, 0.344221, -0.143919, 597.612915, 1498.124878, -0.857597, -0.397067, 573.536011, 962.256226, -0.386179, 0.079731, 560.360718, 841.627625, -0.237931, -0.401949, 449.192841, 629.345520, -0.160663, 0.002935], dtype=np.float32)
    flow_features = flow_features.reshape(-1, n_flow_features)
    print("track feature shape:", track_features.shape)
    print("flow feature shape:", flow_features.shape)

    results = ort_sess.run(output_names, {
        input_names[0]: jet_features, 
        input_names[1]: track_features,
        input_names[2]: flow_features
        })
    return results


In [6]:
model_path_base = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev"
m1_config = [f"{model_path_base}/BTagging/20230307/gn2v00/antiktvr30rmax4rmin02track/network.onnx", 21]
m2_config = [f"{model_path_base}/BTagging/20231205/GN2v01/antikt4empflow/network_fold0.onnx", 19]
m3_config = [f"{model_path_base}/BTagging/20240726/GN2Xv02/antikt10ufo/network.onnx", 20]

In [9]:
apply_GN2Xv02(m3_config[0])

Input names:  ['jet_features', 'track_features', 'flow_features'] with shapes:  [[1, 2], ['n_tracks', 20], ['n_flow', 4]]
Output names:  ['GN2Xv02_phbb', 'GN2Xv02_phcc', 'GN2Xv02_ptop', 'GN2Xv02_pqcd'] with shapes:  [[], [], [], []]
track feature shape: (16, 20)
flow feature shape: (23, 4)


[array(nan, dtype=float32),
 array(nan, dtype=float32),
 array(nan, dtype=float32),
 array(nan, dtype=float32)]

In [8]:
inspect_and_apply(*m1_config)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 21]]
Output names:  ['pu', 'pc', 'pb'] with shapes:  [[1], [1], [1]]
jet feature shape: (1, 2)
track feature shape: (1, 21)


[array([0.8281662], dtype=float32),
 array([0.10067604], dtype=float32),
 array([0.07115781], dtype=float32)]

In [6]:
inspect_and_apply(*m1_config, none_input=True)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 21]]
Output names:  ['pu', 'pc', 'pb'] with shapes:  [[1], [1], [1]]
jet feature shape: (1, 2)
track feature shape: (0, 21)


[array([0.6114677], dtype=float32),
 array([0.26959956], dtype=float32),
 array([0.11893278], dtype=float32)]

In [11]:
inspect_and_apply(*m2_config)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 19]]
Output names:  ['GN2v01_pb', 'GN2v01_pc', 'GN2v01_pu', 'GN2v01_ptau', 'GN2v01_TrackOrigin', 'GN2v01_VertexIndex'] with shapes:  [[], [], [], [], ['n_tracks'], ['n_tracks']]
jet feature shape: (1, 2)
track feature shape: (1, 19)


[array(0.00410247, dtype=float32),
 array(0.07487475, dtype=float32),
 array(0.5267351, dtype=float32),
 array(0.39428762, dtype=float32),
 array([6], dtype=int8),
 array([0], dtype=int8)]

In [9]:
inspect_and_apply(*m2_config, none_input=True)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 19]]
Output names:  ['GN2v01_pb', 'GN2v01_pc', 'GN2v01_pu', 'GN2v01_ptau', 'GN2v01_TrackOrigin', 'GN2v01_VertexIndex'] with shapes:  [[], [], [], [], ['n_tracks'], ['n_tracks']]
jet feature shape: (1, 2)
track feature shape: (0, 19)


[1;31m2024-09-13 14:32:15.132987045 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Shape node. Name:'/Shape' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {2} != {3}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.
[m


RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Shape node. Name:'/Shape' Status Message: /onnxruntime_src/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {2} != {3}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.


In [None]:
# numbers are taken from: 
# https://gitlab.cern.ch/atlas/athena/-/merge_requests/69721

# jet_features = np.array([[11516.2, 1.0935]], dtype=np.float32)

# track_features = np.array([-0.113795, 0.173364, -0.230979, -0.233828, 0.000914455, 1.1331, 1.3538, 0.00291881, 0.00241196, 1.61412e-05, 4, 8, 1, 1, 0, 0, 0, 0, 0, 0, 0, -0.126863, 0.00371974, 0.17269, 0.135184, -0.000857426, 0.904411, -0.0226558, 0.00404071, 0.00238144, 1.47576e-05, 4, 7, 1, 1, 0, 0, 0, 0, 0, 0, 1], dtype=np.float32)
# track_features = track_features.reshape(-1, n_track_features)

# print("jet feature shape:", jet_features.shape)
# print("track feature shape:", track_features.shape)

# pu, pc, pb = ort_sess.run(output_names, {input_names[0]: jet_features, input_names[1]: track_features})

# print(f"pu: {pu}")
# print(f"pc: {pc}")
# print(f"pb: {pb}")