In [1]:
import onnx
import onnxruntime as ort
import numpy as np

In [2]:
# Record the library versions the outputs below were produced with
# (onnx is used for loading/checking, onnxruntime for inference, numpy for inputs).
print("onnxruntime:", ort.__version__)
print("onnx:", onnx.__version__)
print("numpy:", np.__version__)

1.15.1


In [17]:
def inspect_model(model_path: str, n_track_features: int, none_input: bool = False,
                  jet_features: "np.ndarray | None" = None, n_tracks: int = 1,
                  print_graph: bool = False):
    """Validate an ONNX tagger model and run it once on dummy inputs.

    The model is checked with ``onnx.checker`` and then loaded into an
    onnxruntime ``InferenceSession``. The function prints the model's
    input/output names and shapes, then runs one inference with a single jet
    and all-zero track features.

    Args:
        model_path: Path to the ``.onnx`` file.
        n_track_features: Number of features per track. This must match the
            last dimension of the model's second input (21 and 19 for the
            models inspected in this notebook).
        none_input: If True, feed a track array with zero rows, of shape
            ``(0, n_track_features)``, to exercise the no-track case.
        jet_features: Optional array fed as the first input. Defaults to
            ``[[85507.8, -3.05748]]``. The values are presumably jet pT [MeV]
            and eta (TODO: confirm against the model's training config).
            1-D input is promoted to shape ``(1, k)``.
        n_tracks: Number of all-zero track rows fed when ``none_input`` is
            False. Defaults to 1, matching the previous behaviour.
        print_graph: If True, print the human-readable graph.
            ``onnx.helper.printable_graph`` only *returns* a string, so it must
            be printed to have any visible effect.

    Returns:
        list: The arrays returned by ``InferenceSession.run``, in the order of
        the model's outputs.

    Raises:
        ValueError: If the model does not have exactly two inputs.

    Note:
        onnxruntime may log warnings about unused initializers while building
        the session. The GN2v01 model in this notebook does this.
    """
    model = onnx.load(model_path)
    onnx.checker.check_model(model)
    if print_graph:
        print(onnx.helper.printable_graph(model.graph))

    ort_sess = ort.InferenceSession(model_path)
    sess_inputs = ort_sess.get_inputs()
    sess_outputs = ort_sess.get_outputs()
    input_names = [x.name for x in sess_inputs]
    input_shapes = [x.shape for x in sess_inputs]
    output_names = [x.name for x in sess_outputs]
    output_shapes = [x.shape for x in sess_outputs]
    print("Input names: ", input_names, "with shapes: ", input_shapes)
    print("Output names: ", output_names, "with shapes: ", output_shapes)
    # The feed dict below is built positionally, so fail clearly on unexpected models.
    if len(input_names) != 2:
        raise ValueError(
            f"expected a model with 2 inputs (jet, track), got {len(input_names)}: {input_names}"
        )

    if jet_features is None:
        jet_features = [[85507.8, -3.05748]]
    jet_features = np.atleast_2d(np.asarray(jet_features, dtype=np.float32))
    # Zero rows exercise the "jet with no tracks" path; otherwise feed all-zero tracks.
    n_rows = 0 if none_input else n_tracks
    track_features = np.zeros((n_rows, n_track_features), dtype=np.float32)
    print("jet feature shape:", jet_features.shape)
    print("track feature shape:", track_features.shape)

    # First input = jet features, second input = track features.
    results = ort_sess.run(output_names, {input_names[0]: jet_features, input_names[1]: track_features})
    return results


In [18]:
# Shared CVMFS location of the b-tagging model files inspected below.
BTAG_GROUPDATA_DIR = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/BTagging"

# (model_path, n_track_features) records, unpacked into inspect_model.
# The feature count must equal the last dimension of each model's
# "track_features" input (printed as ['n_tracks', 21] / ['n_tracks', 19]).
m1_config = (f"{BTAG_GROUPDATA_DIR}/20230307/gn2v00/antiktvr30rmax4rmin02track/network.onnx", 21)  # gn2v00
m2_config = (f"{BTAG_GROUPDATA_DIR}/20231205/GN2v01/antikt4empflow/network_fold0.onnx", 19)  # GN2v01, fold 0

In [19]:
# gn2v00: one dummy jet with a single all-zero track.
inspect_model(*m1_config, none_input=False)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 21]]
Output names:  ['pu', 'pc', 'pb'] with shapes:  [[1], [1], [1]]
jet feature shape: (1, 2)
track feature shape: (1, 21)


[array([0.8281662], dtype=float32),
 array([0.10067604], dtype=float32),
 array([0.07115781], dtype=float32)]

In [20]:
# GN2v01: one dummy jet with a single all-zero track.
inspect_model(*m2_config, none_input=False)

Input names:  ['jet_features', 'track_features'] with shapes:  [[1, 2], ['n_tracks', 19]]
Output names:  ['GN2v01_pb', 'GN2v01_pc', 'GN2v01_pu', 'GN2v01_ptau', 'GN2v01_TrackOrigin', 'GN2v01_VertexIndex'] with shapes:  [[], [], [], [], ['n_tracks'], ['n_tracks']]
jet feature shape: (1, 2)
track feature shape: (1, 19)


2024-04-09 17:19:35.128450888 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer '1070'. It is not used by any node and should be removed from the model.
2024-04-09 17:19:35.128470579 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer '1064'. It is not used by any node and should be removed from the model.
2024-04-09 17:19:35.128480419 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer '1101'. It is not used by any node and should be removed from the model.
2024-04-09 17:19:35.128484189 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer '1095'. It is not used by any node and should be removed from the model.
2024-04-09 17:19:35.128491009 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer '1126'. It is not used by any node and should be removed from the model.
2024-04-09 17:19:35.128494459 [W:onnxruntime:

[array(0.00410247, dtype=float32),
 array(0.07487475, dtype=float32),
 array(0.5267351, dtype=float32),
 array(0.39428762, dtype=float32),
 array([6], dtype=int8),
 array([0], dtype=int8)]

In [11]:
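# Numbers are taken from:
# https://gitlab.cern.ch/atlas/athena/-/merge_requests/69721
# NOTE: this is a stale reference snippet. `ort_sess`, `input_names`,
# `output_names` and `n_track_features` are locals of `inspect_model`, not
# notebook globals, so uncommenting this cell as-is raises NameError.
# The 42 track values are 2 tracks x 21 features, i.e. they match m1_config
# (gn2v00), whose outputs are (pu, pc, pb). GN2v01 has 6 outputs, so the
# 3-way unpacking below would fail for it.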

# jet_features = np.array([[11516.2, 1.0935]], dtype=np.float32)

# track_features = np.array([-0.113795, 0.173364, -0.230979, -0.233828, 0.000914455, 1.1331, 1.3538, 0.00291881, 0.00241196, 1.61412e-05, 4, 8, 1, 1, 0, 0, 0, 0, 0, 0, 0, -0.126863, 0.00371974, 0.17269, 0.135184, -0.000857426, 0.904411, -0.0226558, 0.00404071, 0.00238144, 1.47576e-05, 4, 7, 1, 1, 0, 0, 0, 0, 0, 0, 1], dtype=np.float32)
# track_features = track_features.reshape(-1, n_track_features)

# print("jet feature shape:", jet_features.shape)
# print("track feature shape:", track_features.shape)

# pu, pc, pb = ort_sess.run(output_names, {input_names[0]: jet_features, input_names[1]: track_features})

# print(f"pu: {pu}")
# print(f"pc: {pc}")
# print(f"pb: {pb}")