In [None]:
from typing import Optional, List, Dict

import numpy as np
import scipy

import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http

from tqdm import tqdm

In [None]:
# Automatically reload imported modules before executing each cell, so edits to
# imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Adapted from https://github.com/lgray/hgg-coffea/blob/triton-bdts/src/hgg_coffea/tools/chained_quantile.py
class wrapped_triton:
    """Callable wrapper around a model served by a Triton inference server.

    The model is addressed by a URL of the form
    ``triton+<protocol>://<host:port>/<model name>/<model version>`` where
    ``<protocol>`` is ``grpc`` or ``http``. Calling the wrapper sends every
    array of ``input_dict`` as an FP32 input named after its key and returns
    the requested output tensor as a numpy array.

    The Triton client is created on the first call and reused for later calls.
    It is dropped when the wrapper is pickled, because the cached tritonclient
    module (and presumably the client connection) cannot be pickled.
    """

    def __init__(
        self,
        model_url: str,
        output_name: str = "softmax",
    ) -> None:
        """Parse ``model_url`` into protocol, server address, model name and version.

        Args:
            model_url: ``triton+<grpc|http>://<host:port>/<model>/<version>``.
            output_name: Name of the model output to request and return.

        Raises:
            ValueError: if the URL does not split into the expected pieces
                (raised by the tuple unpacking below).
        """
        fullprotocol, location = model_url.split("://")
        _, protocol = fullprotocol.split("+")
        address, model, version = location.split("/")

        self._protocol = protocol
        self._address = address
        self._model = model
        self._version = version
        self._output_name = output_name
        # Lazily created client and the matching tritonclient module (grpc or http).
        self._client = None
        self._triton_protocol = None

    def __getstate__(self) -> dict:
        """Return picklable state, dropping the cached client and protocol module."""
        state = self.__dict__.copy()
        state["_client"] = None
        state["_triton_protocol"] = None
        return state

    def _get_client(self):
        """Return ``(client, tritonclient module)``, creating the client on first use.

        Raises:
            ValueError: if the protocol parsed from the URL is neither grpc nor http.
        """
        if self._client is None:
            if self._protocol == "grpc":
                client = triton_grpc.InferenceServerClient(url=self._address, verbose=False)
                triton_protocol = triton_grpc
            elif self._protocol == "http":
                client = triton_http.InferenceServerClient(
                    url=self._address,
                    verbose=False,
                    concurrency=12,
                )
                triton_protocol = triton_http
            else:
                raise ValueError(
                    f"{self._protocol} does not encode a valid protocol (grpc or http)"
                )
            self._client = client
            self._triton_protocol = triton_protocol
        return self._client, self._triton_protocol

    def __call__(self, input_dict: Dict[str, np.ndarray]) -> np.ndarray:
        """Run inference on ``input_dict`` and return the requested output.

        Args:
            input_dict: Mapping of Triton input name -> array. Every input is
                declared as FP32, so arrays are converted to contiguous float32
                first (``set_data_from_numpy`` rejects a dtype that does not
                match the declared datatype).

        Returns:
            The output tensor named ``output_name`` as a numpy array.

        Raises:
            ValueError: if the URL's protocol is neither grpc nor http.
        """
        client, triton_protocol = self._get_client()

        inputs = []
        for name, array in input_dict.items():
            data = np.ascontiguousarray(array, dtype=np.float32)
            infer_input = triton_protocol.InferInput(name, data.shape, "FP32")
            infer_input.set_data_from_numpy(data)
            inputs.append(infer_input)

        requested_output = triton_protocol.InferRequestedOutput(self._output_name)

        result = client.infer(
            self._model,
            model_version=self._version,
            inputs=inputs,
            outputs=[requested_output],
        )
        return result.as_numpy(self._output_name)

In [None]:
# Random test inputs: `batch_size` samples, each with `pfs` PF slots and `svs` SV slots.
batch_size = 4
pfs = 128  # alternative value used previously: 100
svs = 10  # alternative value used previously: 7
np.random.seed(42)


def _random_input(n_channels, n_slots, mask=False):
    """Draw a (batch_size, n_channels, n_slots) float32 array from U[0, 1).

    With ``mask=True`` the draws are thresholded at 0.2, giving a 0/1 mask in
    which about 80% of the entries are 1.
    """
    draws = np.random.rand(batch_size, n_channels, n_slots)
    if mask:
        draws = draws > 0.2
    return draws.astype("float32")


# The entries are drawn in this exact order, which fixes the seeded values.
input_dict = {
    "pf_features": _random_input(25, pfs),
    "pf_vectors": _random_input(4, pfs),
    "pf_mask": _random_input(1, pfs, mask=True),
    "sv_features": _random_input(11, svs),
    "sv_vectors": _random_input(4, svs),
    "sv_mask": _random_input(1, svs, mask=True),
}

# Alternative input schemas (presumably for other model versions):
#   - pf_points (2), pf_features (19), pf_mask (1), sv_points (2), sv_features (11), sv_mask (1)
#   - the same six inputs with positional suffixes: pf_points__0 ... sv_mask__5

In [None]:
# Other Triton endpoints/models that have been queried from this notebook:
#   triton+grpc://ailab01.fnal.gov:8001/particlenet_hww/1
#   triton+grpc://prp-gpu-1.t2.ucsd.edu:8001/particlenet_hww/1
#   triton+grpc://67.58.49.48:8001/ak8_MD_vminclv2ParT_manual_fixwrap/1
model_url = "triton+grpc://67.58.49.48:8001/2023May30_ak8_MD_inclv8_part_2reg_manual/1"
triton_model = wrapped_triton(model_url)

# Send the request (a single pass); `output` holds the result of the last call.
for call_idx in tqdm(range(1)):
    output = triton_model(input_dict)
print(output)

In [None]:
import onnx
import onnxruntime as ort


model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)
# The last two outputs of this "2reg" model are regression values, not class
# scores, so they are excluded from the softmax-normalisation check.
n_regression_outputs = 2

# Validate the ONNX graph, then run it on the same random inputs.
onnx_model = onnx.load(model_dir + "model.onnx")
onnx.checker.check_model(onnx_model)

# Explicit providers: ONNX Runtime >= 1.9 requires them on builds with non-CPU providers.
ort_sess = ort.InferenceSession(model_dir + "model.onnx", providers=["CPUExecutionProvider"])
outputs = ort_sess.run(None, input_dict)[0]  # first graph output
cls_scores = outputs[:, :-n_regression_outputs]
print("ONNX outputs:", outputs)
print("Shape:", outputs.shape)
print("Softmax applied:", np.allclose(np.sum(cls_scores, axis=1), 1, atol=1e-5))

In [None]:
import onnx
import onnxruntime as ort


model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)
# The last two outputs of this "2reg" model are regression values, not class
# scores, so they are excluded from the softmax-normalisation check.
n_regression_outputs = 2

onnx_model = onnx.load(model_dir + "model.onnx")
onnx.checker.check_model(onnx_model)

# Explicit providers: ONNX Runtime >= 1.9 requires them on builds with non-CPU providers.
ort_sess = ort.InferenceSession(model_dir + "model.onnx", providers=["CPUExecutionProvider"])
# Keep the full list returned by run(); the following cell indexes outputs[0].
outputs = ort_sess.run(None, input_dict)
print("ONNX outputs:", outputs[0])
print("Shape:", outputs[0].shape)
print(
    "Softmax applied:",
    np.allclose(np.sum(outputs[0][:, :-n_regression_outputs], axis=1), 1, atol=1e-5),
)

In [None]:
# Per-sample sum of the class scores (every column except the last two regression
# outputs); ~1 when a softmax has been applied to the class scores.
cls_sums = outputs[0][:, :-2].sum(axis=1)
print("ONNX cls output sums:", cls_sums)

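## 2023May30 model

Rebuild the ParT tagger in PyTorch, load the 2023May30 checkpoint, export it to ONNX, and run it on the same random inputs for comparison.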

In [None]:
import torch
from ParticleTransformer import ParticleTransformerTagger

# ParT tagger for the 2023May30 "2reg" checkpoint, built on CPU. The
# hyper-parameters must match those used in training so that the checkpoint's
# state dict loads (load_state_dict is strict by default).
part_model = ParticleTransformerTagger(
    pf_input_dim=25,
    sv_input_dim=11,
    num_classes=314,  # total outputs: the last 2 are regression values ("2reg"), the rest class scores
    # network configurations
    pair_input_dim=4,
    embed_dims=[128, 512, 128],
    pair_embed_dims=[64, 64, 64],
    num_heads=8,
    num_layers=8,
    num_cls_layers=2,
    block_params=None,
    cls_block_params={"dropout": 0, "attn_dropout": 0, "activation_dropout": 0},
    fc_params=[],
    activation="gelu",
    # misc
    trim=True,
    for_inference=True,
).cpu()

In [None]:
model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)
# NOTE: torch.load unpickles the checkpoint - only load trusted files.
part_model.load_state_dict(
    torch.load(model_dir + "net_best_epoch_state.pt", map_location=torch.device("cpu"))
)
_ = part_model.eval()

# Input names and shapes (-1 = batch axis); PF inputs have `pfs` slots, SV inputs `svs`.
data_config = {
    "input_names": ["pf_features", "pf_vectors", "pf_mask", "sv_features", "sv_vectors", "sv_mask"],
    "input_shapes": {
        "pf_features": (-1, 25, pfs),
        "pf_vectors": (-1, 4, pfs),
        "pf_mask": (-1, 1, pfs),
        "sv_features": (-1, 11, svs),
        "sv_vectors": (-1, 4, svs),
        "sv_mask": (-1, 1, svs),
    },
}

# Export metadata: batch-size-1 example shapes. Axis 0 ("N") and axis 2
# ("n_pf" / "n_sv", the number of slots) are declared dynamic.
model_info = {
    "input_names": list(data_config["input_names"]),
    "input_shapes": {k: ((1,) + s[1:]) for k, s in data_config["input_shapes"].items()},
    "output_names": ["softmax"],
    "dynamic_axes": {
        **{k: {0: "N", 2: "n_" + k.split("_")[0]} for k in data_config["input_names"]},
        **{"softmax": {0: "N"}},
    },
}

# All-ones dummy inputs used for tracing, in input_names order.
inputs = tuple(
    torch.ones(model_info["input_shapes"][k], dtype=torch.float32)
    for k in model_info["input_names"]
)
# Write the ONNX file next to the checkpoint, i.e. the model_dir + "model.onnx" path the
# ONNX Runtime cells read. (Appending the run-directory name to model_dir again
# pointed at a nested .../<run>/<run>/model.onnx path.)
export_path = model_dir + "model.onnx"
torch.onnx.export(
    part_model,
    inputs,
    export_path,
    input_names=model_info["input_names"],
    output_names=model_info["output_names"],
    dynamic_axes=model_info["dynamic_axes"],
    opset_version=11,
)

In [None]:
# PyTorch reference forward pass on the same random inputs, in input_dict order.
# Mask inputs are passed as bool tensors; they are picked by key name rather than
# by position, so reordering input_dict cannot silently convert the wrong input.
in_tensors = [
    torch.tensor(val, dtype=torch.float32).bool()
    if key.endswith("_mask")
    else torch.tensor(val, dtype=torch.float32)
    for key, val in input_dict.items()
]

# Inference only: no autograd graph needed.
with torch.no_grad():
    out = part_model(*in_tensors)
print(out)

In [None]:
# Triton output for the same random inputs, for side-by-side comparison with `out`.
output

In [None]:
# PyTorch output for the same random inputs.
out

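## Dec22 model

Same procedure for the Dec22 checkpoint (`net_best_epoch_state.pt` in the working directory), built with `ParticleTransformerHidden` and 37 class outputs plus one regression output.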

In [None]:
import torch
from ParticleTransformerHidden import ParticleTransformerTagger

# Dec22 checkpoint: ParT tagger (ParticleTransformerHidden implementation) with
# 37 class outputs plus one regression output, built on CPU.
part_model = ParticleTransformerTagger(
    # input feature dimensions
    pf_input_dim=25,
    sv_input_dim=11,
    pair_input_dim=4,
    # outputs: 37 classes + one dim for regression
    num_classes=37 + 1,
    # transformer architecture
    embed_dims=[128, 512, 128],
    pair_embed_dims=[64, 64, 64],
    num_heads=8,
    num_layers=8,
    num_cls_layers=2,
    block_params=None,
    cls_block_params={"dropout": 0, "attn_dropout": 0, "activation_dropout": 0},
    fc_params=[],
    activation="gelu",
    # inference-time options
    trim=True,
    for_inference=True,
).cpu()

In [None]:
# Load the Dec22 checkpoint from the working directory and export it to ./model.onnx.
# NOTE: torch.load unpickles the file - only load trusted checkpoints.
part_model.load_state_dict(
    torch.load("net_best_epoch_state.pt", map_location=torch.device("cpu"))
)
_ = part_model.eval()

# Channel count per input; PF inputs span `pfs` slots and SV inputs `svs` slots.
channels = {
    "pf_features": 25,
    "pf_vectors": 4,
    "pf_mask": 1,
    "sv_features": 11,
    "sv_vectors": 4,
    "sv_mask": 1,
}
data_config = {
    "input_names": list(channels),
    "input_shapes": {
        name: (-1, n_ch, pfs if name.startswith("pf_") else svs)
        for name, n_ch in channels.items()
    },
}

# Export metadata: batch-size-1 example shapes; axis 0 ("N") and axis 2
# ("n_pf" / "n_sv") are dynamic.
dynamic_axes = {
    name: {0: "N", 2: "n_" + name.split("_")[0]} for name in data_config["input_names"]
}
dynamic_axes["softmax"] = {0: "N"}
model_info = {
    "input_names": list(data_config["input_names"]),
    "input_shapes": {
        name: (1,) + shape[1:] for name, shape in data_config["input_shapes"].items()
    },
    "output_names": ["softmax"],
    "dynamic_axes": dynamic_axes,
}

# All-ones dummy inputs used for tracing, in input_names order.
inputs = tuple(
    torch.ones(model_info["input_shapes"][name], dtype=torch.float32)
    for name in model_info["input_names"]
)
torch.onnx.export(
    part_model,
    inputs,
    "model.onnx",
    input_names=model_info["input_names"],
    output_names=model_info["output_names"],
    dynamic_axes=model_info["dynamic_axes"],
    opset_version=11,
)

In [None]:
# PyTorch reference forward pass on the same random inputs, in input_dict order.
# Mask inputs are passed as bool tensors; they are picked by key name rather than
# by position, so reordering input_dict cannot silently convert the wrong input.
in_tensors = [
    torch.tensor(val, dtype=torch.float32).bool()
    if key.endswith("_mask")
    else torch.tensor(val, dtype=torch.float32)
    for key, val in input_dict.items()
]

# Inference only: no autograd graph needed.
with torch.no_grad():
    out = part_model(*in_tensors)
print(out)

In [None]:
# Triton output again, for comparison with the Dec22 PyTorch output printed above.
output