# Triton & ONNX
- runs inference on triton
- runs inference from onnx

In [1]:
from typing import Optional, List, Dict

import nump# Exploring the weights
- processes an `events[year][ch][sample]` object using `make_events_dict()`
- investigate e.g. PU weightsy as np

import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# from https://github.com/lgray/hgg-coffea/blob/triton-bdts/src/hgg_coffea/tools/chained_quantile.py
class wrapped_triton:
    def __init__(
        self,
        model_url: str,
    ) -> None:
        fullprotocol, location = model_url.split("://")
        _, protocol = fullprotocol.split("+")
        address, model, version = location.split("/")

        self._protocol = protocol
        self._address = address
        self._model = model
        self._version = version

    def __call__(self, input_dict: Dict[str, np.ndarray]) -> np.ndarray:
        if self._protocol == "grpc":
            client = triton_grpc.InferenceServerClient(url=self._address, verbose=False)
            triton_protocol = triton_grpc
        elif self._protocol == "http":
            client = triton_http.InferenceServerClient(
                url=self._address,
                verbose=False,
                concurrency=12,
            )
            triton_protocol = triton_http
        else:
            raise ValueError(f"{self._protocol} does not encode a valid protocol (grpc or http)")

        # Infer
        inputs = []

        for key in input_dict:
            input = triton_protocol.InferInput(key, input_dict[key].shape, "FP32")
            input.set_data_from_numpy(input_dict[key])
            inputs.append(input)

        output = triton_protocol.InferRequestedOutput("softmax")

        request = client.infer(
            self._model,
            model_version=self._version,
            inputs=inputs,
            outputs=[output],
        )

        out = request.as_numpy("softmax")

        return out

In [26]:
batch_size = 5
pfs = 128
svs = 10

input_dict = {
    "pf_features": np.random.rand(batch_size, 25, pfs).astype("float32"),
    "pf_vectors": np.random.rand(batch_size, 4, pfs).astype("float32"),
    "pf_mask": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
    "sv_features": np.random.rand(batch_size, 11, svs).astype("float32"),
    "sv_vectors": np.random.rand(batch_size, 4, svs).astype("float32"),
    "sv_mask": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
}

In [27]:
# run mode
# model_url = "triton+grpc://67.58.49.52:8001/ak8_MD_vminclv2ParT_manual_fixwrap/1"
model_url = "triton+grpc://67.58.49.52:8001/ak8_MD_vminclv2ParT_manual_fixwrap_all_nodes/1"
# model_url = "triton+grpc://localhost:8001/particlenet_hww_ul_4q_3q/1"
# model_url = "triton+grpc://67.58.49.52:8001/particlenet_hww_ul_4q_3q/1"
triton_model = wrapped_triton(model_url)
for i in tqdm(range(1)):
    output = triton_model(input_dict)
# print(output)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.49it/s]


In [28]:
output

array([[ 1.04044715e-03,  4.19112574e-03,  3.08266841e-03,
         6.77172840e-03,  1.09085375e-02,  3.74644659e-02,
         1.10819901e-03,  9.51661787e-04,  1.39991054e-03,
         1.44925353e-03,  5.86194510e-04,  2.19112681e-03,
         4.15044517e-04,  1.64529265e-04,  1.50348336e-04,
         1.44833684e-04,  1.03220969e-01,  3.88819650e-02,
         3.99476103e-03,  6.21663593e-03,  2.28791123e-05,
         1.08433517e-06,  5.32731008e-07,  9.52191055e-02,
         3.68980728e-02,  1.73061192e-01,  4.37834822e-02,
         3.16666700e-02,  4.26374413e-02,  9.03337821e-02,
         1.36058077e-01,  1.01278037e-01,  6.36519492e-03,
         5.52568957e-03,  2.80237477e-03,  8.84246919e-03,
         1.16950611e-03,  9.75663147e+01, -2.21926308e+00,
        -4.10098600e+00,  1.90281332e-01, -1.59158313e+00,
         3.93356979e-02, -1.21985324e-01,  1.87750086e-01,
         1.10192883e+00,  4.50458944e-01,  4.61917818e-01,
        -2.35008550e+00,  2.47295156e-01,  8.22852626e-0

In [25]:
output[0].shape

(166,)

# Run model using onnx

In [31]:
import onnx
import onnxruntime as ort

batch_size = 20
pfs = 128
svs = 10

input_dict = {
    "pf_features": np.random.rand(batch_size, 25, pfs).astype("float32"),
    "pf_vectors": np.random.rand(batch_size, 4, pfs).astype("float32"),
    "pf_mask": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
    "sv_features": np.random.rand(batch_size, 11, svs).astype("float32"),
    "sv_vectors": np.random.rand(batch_size, 4, svs).astype("float32"),
    "sv_mask": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
}

onnx_model = onnx.load("/Users/fmokhtar/projects/weaver-core-dev/ak8_MD_vminclv2ParT_manual_fixwrap/1/model.onnx")
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession("/Users/fmokhtar/projects/weaver-core-dev/ak8_MD_vminclv2ParT_manual_fixwrap/1/model.onnx")
outputs = ort_sess.run(None, input_dict)
print(outputs)

[array([[4.86524962e-03, 4.97966399e-03, 1.74506661e-03, 1.17051192e-02,
        6.81609195e-03, 5.58614638e-03, 2.07647705e-03, 8.02133349e-04,
        1.28070323e-03, 9.03823646e-04, 8.59912427e-04, 1.36731251e-03,
        5.05559612e-04, 1.48216175e-04, 1.17055039e-04, 7.71649793e-05,
        4.24696133e-02, 3.44527699e-02, 1.47807617e-02, 2.23821755e-02,
        1.27572479e-04, 6.49630329e-06, 1.83118209e-06, 5.21351621e-02,
        3.65206562e-02, 1.85362101e-01, 8.91993344e-02, 2.09410697e-01,
        2.67438330e-02, 3.88626419e-02, 1.29468903e-01, 4.75959405e-02,
        6.96596084e-03, 5.73533494e-03, 2.12198449e-03, 9.53700114e-03,
        2.28361320e-03],
       [2.21705972e-03, 4.84761875e-03, 4.28972626e-03, 4.10481263e-03,
        6.57369848e-03, 1.64161697e-02, 6.01862092e-04, 4.78196889e-04,
        1.61183663e-02, 1.34124737e-02, 1.45546772e-04, 2.76738923e-04,
        8.05218797e-03, 2.30695214e-03, 1.15670846e-05, 9.05202251e-06,
        5.45549653e-02, 8.37783888e-02

In [32]:
outputs[0].shape

(20, 37)