In [1]:
from typing import Optional, List, Dict

import numpy as np

import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http


In [2]:
# Re-import edited modules automatically before each cell runs, so changes to
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:

# from https://github.com/lgray/hgg-coffea/blob/triton-bdts/src/hgg_coffea/tools/chained_quantile.py
class wrapped_triton:
    """Callable wrapper that sends one inference request to a Triton server.

    The endpoint is described by a single URL of the form
    ``<prefix>+<protocol>://<address>/<model>/<version>``, for example
    ``triton+grpc://host:8001/my_model/1``. ``protocol`` must be ``grpc`` or
    ``http``; this is checked when the instance is called, not at construction.

    Args:
        model_url: Endpoint description in the format above.
        output_name: Name of the output tensor to request from the model.
            Defaults to ``"softmax"``.

    Raises:
        ValueError: If ``model_url`` does not split into the expected parts.
    """

    def __init__(
        self,
        model_url: str,
        output_name: str = "softmax",
    ) -> None:
        try:
            fullprotocol, location = model_url.split("://")
            _, protocol = fullprotocol.split("+")
            address, model, version = location.split("/")
        except ValueError as err:
            # Tuple unpacking fails with an opaque "not enough values to unpack";
            # re-raise with the format the caller is expected to provide.
            raise ValueError(
                f"{model_url!r} is not of the form "
                "'<prefix>+<protocol>://<address>/<model>/<version>'"
            ) from err

        self._protocol = protocol
        self._address = address
        self._model = model
        self._version = version
        self._output_name = output_name

    def __call__(self, input_dict: Dict[str, np.ndarray]) -> np.ndarray:
        """Run inference on ``input_dict`` and return the requested output.

        Args:
            input_dict: Maps each model input name to the array sent for it.
                Every input is declared to the server with datatype ``FP32``.

        Returns:
            The requested output tensor as returned by ``as_numpy``.

        Raises:
            ValueError: If the protocol parsed from the URL is neither
                ``grpc`` nor ``http``.
        """
        if self._protocol == "grpc":
            triton_protocol = triton_grpc
            client_options = {}
        elif self._protocol == "http":
            triton_protocol = triton_http
            client_options = {"concurrency": 12}
        else:
            raise ValueError(f"{self._protocol} does not encode a valid protocol (grpc or http)")

        client = triton_protocol.InferenceServerClient(
            url=self._address,
            verbose=False,
            **client_options,
        )

        # A new client is opened for every call, so it must be closed again
        # even when the request fails; otherwise connections leak.
        try:
            infer_inputs = []
            for key, array in input_dict.items():
                infer_input = triton_protocol.InferInput(key, array.shape, "FP32")
                infer_input.set_data_from_numpy(array)
                infer_inputs.append(infer_input)

            requested_output = triton_protocol.InferRequestedOutput(self._output_name)

            result = client.infer(
                self._model,
                model_version=self._version,
                inputs=infer_inputs,
                outputs=[requested_output],
            )

            return result.as_numpy(self._output_name)
        finally:
            client.close()

In [25]:
batch_size = 880
# Length of the last axis of the pf_* and sv_* inputs, respectively.
pfs = 128
svs = 10

# Fixed seed so that every run feeds the same random inputs to each backend,
# which keeps the outputs comparable across kernel restarts.
SEED = 42
rng = np.random.default_rng(SEED)


def _random_features(n_channels, n_items):
    """Return float32 uniform [0, 1) noise of shape (batch_size, n_channels, n_items)."""
    return rng.random((batch_size, n_channels, n_items)).astype("float32")


def _random_mask(n_items):
    """Return a float32 0/1 array of shape (batch_size, 1, n_items).

    An entry is 1 where a uniform [0, 1) draw exceeds 0.2, else 0.
    """
    return (rng.random((batch_size, 1, n_items)) > 0.2).astype("float32")


# Earlier variants of this cell (kept here as a note instead of commented-out
# code) used pfs = 100, svs = 7 and the input layout
#   pf_points (2), pf_features (19), pf_mask (1),
#   sv_points (2), sv_features (11), sv_mask (1),
# optionally with the names suffixed __0 ... __5 in that order.
input_dict = {
    "pf_features": _random_features(25, pfs),
    "pf_vectors": _random_features(4, pfs),
    "pf_mask": _random_mask(pfs),
    "sv_features": _random_features(11, svs),
    "sv_vectors": _random_features(4, svs),
    "sv_mask": _random_mask(svs),
}

In [26]:
# Endpoint string handed to wrapped_triton, which splits it as
# '<prefix>+<protocol>://<address>/<model>/<version>'.
# Exactly one model_url assignment should be active; the commented-out lines
# are alternative servers/models.
# model_url = "triton+grpc://ailab01.fnal.gov:8001/particlenet_hww/1"
# model_url = "triton+grpc://prp-gpu-1.t2.ucsd.edu:8001/particlenet_hww/1"
model_url = "triton+grpc://67.58.49.52:8001/ak8_MD_vminclv2ParT_manual_fixwrap/1"
# model_url = "triton+grpc://localhost:8001/particlenet_hww_ul_4q_3q/1"
# model_url = "triton+grpc://67.58.49.52:8001/particlenet_hww_ul_4q_3q/1"
triton_model = wrapped_triton(model_url)
# Send the random inputs to the server; `output` is reused by a later cell.
output = triton_model(input_dict)
print(output)

[[1.4366233e-03 2.7789180e-03 1.4810998e-03 ... 1.3419604e-03
  5.3415606e-03 1.0536162e-03]
 [5.8962428e-04 2.6495671e-03 2.1507419e-03 ... 3.6234359e-04
  3.9466871e-05 1.1135796e-02]
 [1.8111541e-03 6.1054532e-03 4.3359781e-03 ... 4.5285108e-03
  4.9895938e-03 4.3286658e-03]
 ...
 [2.6687947e-03 5.2311928e-03 2.8158592e-03 ... 3.7654082e-03
  1.0825735e-02 2.2049018e-03]
 [2.7392500e-03 2.7328818e-03 1.0838965e-03 ... 2.2412995e-03
  1.9178806e-03 1.2975576e-03]
 [1.9648573e-03 4.6280641e-03 3.3738450e-03 ... 3.7505146e-04
  7.2018191e-04 2.7100766e-02]]


In [12]:
import onnx
import onnxruntime as ort

# Single source for the exported model path used by both the checker and the runtime.
onnx_path = "model.onnx"

# Load the graph and run the ONNX checker on it before creating a session.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

# Run the same input_dict through onnxruntime; `outputs` is inspected in the next cell.
ort_sess = ort.InferenceSession(onnx_path)
outputs = ort_sess.run(None, input_dict)
print(outputs)

# Print Result 
# predicted, actual = classes[outputs[0][0].argmax(0)], classes[y]
# print(f'Predicted: "{predicted}", Actual: "{actual}"')

[array([[2.40560714e-03, 4.73927893e-03, 4.42884630e-03, 7.70280650e-03,
        7.44594727e-03, 4.89646792e-02, 2.44864757e-04, 1.84084594e-04,
        3.34733166e-03, 5.55951335e-03, 2.30525038e-04, 5.30063233e-04,
        1.57437345e-03, 1.06911850e-03, 2.14584306e-05, 1.38277428e-05,
        2.32464038e-02, 5.26734889e-02, 9.97571088e-03, 1.04071777e-02,
        3.24207067e-05, 3.28148562e-05, 7.48393745e-07, 7.66754225e-02,
        9.75757986e-02, 1.65003791e-01, 7.47575983e-02, 1.00998513e-01,
        1.26040773e-02, 4.62465100e-02, 6.31558970e-02, 1.60297364e-01,
        6.35669276e-04, 1.13147628e-02, 3.94817034e-04, 9.87346284e-04,
        4.52140672e-03],
       [1.02613051e-03, 3.39163793e-03, 2.07269844e-03, 8.19048751e-03,
        1.25318337e-02, 4.28423323e-02, 6.32008596e-04, 6.25702844e-04,
        1.71326788e-03, 1.57124607e-03, 5.84401190e-04, 2.10922817e-03,
        1.14760280e-03, 5.08062483e-04, 1.14137052e-04, 8.69628202e-05,
        1.43004835e-01, 4.31774035e-02

In [68]:
# Shape of the first array returned by the onnxruntime session.
# NOTE(review): execution counts in this notebook are out of order, so the saved
# output may stem from a different input_dict than the one defined above — re-run to confirm.
outputs[0].shape

(10, 166)

In [69]:
import torch
from ParticleTransformerHidden import ParticleTransformerTagger

# Constructor arguments for the tagger, collected in one place so they can be
# inspected or reused independently of the model instance.
part_model_kwargs = dict(
    pf_input_dim=25,
    sv_input_dim=11,
    num_classes=37 + 1,  # one dim for regression
    # network configurations
    pair_input_dim=4,
    embed_dims=[128, 512, 128],
    pair_embed_dims=[64, 64, 64],
    num_heads=8,
    num_layers=8,
    num_cls_layers=2,
    block_params=None,
    cls_block_params=dict(dropout=0, attn_dropout=0, activation_dropout=0),
    fc_params=[],
    activation="gelu",
    # misc
    trim=True,
    for_inference=True,
)

# Instantiate the network and keep it on the CPU.
part_model = ParticleTransformerTagger(**part_model_kwargs).cpu()

In [70]:
# Restore the trained weights onto the CPU and switch the model to eval mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load("net_best_epoch_state.pt", map_location=torch.device("cpu"))
part_model.load_state_dict(checkpoint)
_ = part_model.eval()

# (name, channels, length) for every model input, in the order used for export.
input_layout = (
    ("pf_features", 25, pfs),
    ("pf_vectors", 4, pfs),
    ("pf_mask", 1, pfs),
    ("sv_features", 11, svs),
    ("sv_vectors", 4, svs),
    ("sv_mask", 1, svs),
)

data_config = {
    "input_names": [spec[0] for spec in input_layout],
    # -1 stands in for the batch axis, which is replaced by 1 for the export below.
    "input_shapes": {spec[0]: (-1, spec[1], spec[2]) for spec in input_layout},
}

# Axis 0 of every tensor and axis 2 of every input are exported as dynamic;
# axis 2 is labelled from the input-name prefix (e.g. "n_pf", "n_sv").
dynamic_axes = {
    name: {0: "N", 2: "n_" + name.split("_")[0]}
    for name in data_config["input_names"]
}
dynamic_axes["softmax"] = {0: "N"}

model_info = {
    "input_names": list(data_config["input_names"]),
    "input_shapes": {
        name: (1, *shape[1:]) for name, shape in data_config["input_shapes"].items()
    },
    "output_names": ["softmax"],
    "dynamic_axes": dynamic_axes,
}

# All-ones float32 example inputs with batch size 1, one per input name.
example_tensors = []
for input_name in model_info["input_names"]:
    example_shape = model_info["input_shapes"][input_name]
    example_tensors.append(torch.ones(example_shape, dtype=torch.float32))
inputs = tuple(example_tensors)

torch.onnx.export(
    part_model,
    inputs,
    "model.onnx",
    input_names=model_info["input_names"],
    output_names=model_info["output_names"],
    dynamic_axes=model_info.get("dynamic_axes"),
    opset_version=11,
)


  assert embed_dim == embed_dim_to_check, \
  assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
  assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
  if attn_mask.shape != correct_3d_size:
  assert key_padding_mask.shape == (bsz, src_len), \
  q = q / math.sqrt(E)


In [59]:
# Convert every array in input_dict to a float32 tensor, keeping the dict's
# insertion order because the model is called positionally below.
# Masks are identified by their "_mask" key suffix rather than by hard-coded
# positions, so reordering or extending input_dict cannot cast the wrong tensor.
in_tensors = [
    torch.tensor(val, dtype=torch.float32).bool()
    if key.endswith("_mask")
    else torch.tensor(val, dtype=torch.float32)
    for key, val in input_dict.items()
]

# Pure inference: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    out = part_model(*in_tensors)
print(out)

tensor([[ 1.5299e-03,  2.9319e-03,  1.9769e-03,  5.4878e-03,  7.8053e-03,
          1.2931e-02,  1.3580e-04,  6.4316e-05,  5.6223e-03,  3.5439e-03,
          5.7096e-05,  7.1246e-05,  2.4305e-03,  7.1927e-04,  4.5472e-05,
          2.3293e-05,  7.5924e-02,  3.5818e-02,  8.9561e-03,  9.3922e-03,
          4.7991e-06,  2.2116e-05,  4.7249e-07,  9.0073e-02,  4.4058e-02,
          1.9110e-01,  5.3661e-02,  8.5326e-02,  2.5597e-02,  6.6441e-02,
          1.0755e-01,  1.1822e-01,  5.6389e-04,  2.4701e-02,  1.2583e-03,
          8.0306e-04,  1.5152e-02,  1.1746e+02, -1.4494e+00, -3.8877e+00,
         -4.1724e-01, -2.1862e+00,  4.6684e-01,  2.2150e-01,  2.5049e-01,
          1.4588e+00,  5.7455e-01,  6.8124e-01, -2.3397e+00, -7.7287e-01,
          7.0747e-01,  1.9500e+00, -3.2054e-01, -3.3750e-01, -3.7859e-01,
         -3.5869e-01,  8.4656e-01,  2.2947e+00, -5.7608e-01,  1.1714e+00,
         -9.2651e-02, -8.0103e-01,  2.5926e-02,  1.4265e+00, -2.4823e+00,
          1.6136e+00,  1.8963e+00,  6.



In [53]:
# Display the array returned by the Triton call (`output` is bound in the cell
# that invokes triton_model) — presumably for comparison with the PyTorch result.
output

array([[ 1.52987312e-03,  2.93191988e-03,  1.97693706e-03,
         5.48777496e-03,  7.80526781e-03,  1.29312817e-02,
         1.35801572e-04,  6.43163876e-05,  5.62225794e-03,
         3.54393735e-03,  5.70958327e-05,  7.12459951e-05,
         2.43050908e-03,  7.19271193e-04,  4.54720794e-05,
         2.32931434e-05,  7.59244114e-02,  3.58182155e-02,
         8.95608403e-03,  9.39224847e-03,  4.79914434e-06,
         2.21154678e-05,  4.72491791e-07,  9.00733918e-02,
         4.40584160e-02,  1.91101983e-01,  5.36605977e-02,
         8.53259340e-02,  2.55966950e-02,  6.64411336e-02,
         1.07548602e-01,  1.18220098e-01,  5.63886017e-04,
         2.47012507e-02,  1.25828688e-03,  8.03059200e-04,
         1.51520669e-02,  1.17462372e+02, -1.44942522e+00,
        -3.88765955e+00, -4.17238146e-01, -2.18617630e+00,
         4.66837943e-01,  2.21496135e-01,  2.50485390e-01,
         1.45877016e+00,  5.74553132e-01,  6.81239963e-01,
        -2.33971190e+00, -7.72875249e-01,  7.07470357e-0

In [51]:
# Display the tensor returned by part_model(*in_tensors).
out

tensor([[ 8.1155e-04,  2.0606e-03,  1.5020e-03,  2.6562e-03,  9.1622e-03,
          3.2599e-03,  1.4781e-03,  9.5434e-04,  3.3028e-04,  6.9261e-04,
          8.2008e-04,  9.9679e-04,  1.7369e-04,  2.3844e-04,  3.0242e-04,
          3.9385e-04,  4.7146e-03,  1.8696e-03,  1.0828e-03,  8.7911e-04,
          1.6338e-04,  4.5681e-05,  9.7771e-06,  1.5454e-01,  3.1454e-02,
          3.7411e-02,  1.8826e-02,  9.6202e-02,  1.1759e-02,  4.5734e-02,
          2.2030e-01,  1.1428e-01,  6.0745e-02,  5.3709e-02,  2.1316e-02,
          4.7606e-02,  5.1513e-02,  4.4246e+01, -2.1041e-02, -3.4620e+00,
         -5.6576e-02,  2.3017e-01, -4.3493e-01, -4.8431e-01,  5.6465e-02,
         -5.0499e-01,  8.3038e-02,  2.3239e-01, -1.3889e+00,  1.7120e+00,
          1.3771e-01, -1.2663e+00,  6.6230e-01,  7.4898e-02, -3.1588e-01,
          3.6870e-01, -6.8668e-01,  1.6357e+00, -8.1999e-01,  2.1713e-01,
          6.4083e-01,  8.7478e-01,  9.4118e-01, -4.9381e-01, -1.2368e+00,
          1.0730e+00,  1.7268e+00,  3.