In [20]:
from typing import Optional, List, Dict

import numpy as np
import scipy

import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http

from tqdm import tqdm

In [10]:
%load_ext autoreload
%autoreload 2

In [25]:
# from https://github.com/lgray/hgg-coffea/blob/triton-bdts/src/hgg_coffea/tools/chained_quantile.py
class wrapped_triton:
    def __init__(
        self,
        model_url: str,
    ) -> None:
        fullprotocol, location = model_url.split("://")
        _, protocol = fullprotocol.split("+")
        address, model, version = location.split("/")

        self._protocol = protocol
        self._address = address
        self._model = model
        self._version = version

    def __call__(self, input_dict: Dict[str, np.ndarray]) -> np.ndarray:
        if self._protocol == "grpc":
            client = triton_grpc.InferenceServerClient(url=self._address, verbose=False)
            triton_protocol = triton_grpc
        elif self._protocol == "http":
            client = triton_http.InferenceServerClient(
                url=self._address,
                verbose=False,
                concurrency=12,
            )
            triton_protocol = triton_http
        else:
            raise ValueError(f"{self._protocol} does not encode a valid protocol (grpc or http)")

        # Infer
        inputs = []

        for key in input_dict:
            input = triton_protocol.InferInput(key, input_dict[key].shape, "FP32")
            input.set_data_from_numpy(input_dict[key])
            inputs.append(input)

        output = triton_protocol.InferRequestedOutput("softmax")

        request = client.infer(
            self._model,
            model_version=self._version,
            inputs=inputs,
            outputs=[output],
        )

        out = request.as_numpy("softmax")

        return out

In [13]:
batch_size = 4
# pfs = 100
# svs = 7
pfs = 128
svs = 10
np.random.seed(42)

# input_dict = {
#     "pf_points": np.random.rand(batch_size, 2, pfs).astype("float32"),
#     "pf_features": np.random.rand(batch_size, 19, pfs).astype("float32"),
#     "pf_mask": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
#     "sv_points": np.random.rand(batch_size, 2, svs).astype("float32"),
#     "sv_features": np.random.rand(batch_size, 11, svs).astype("float32"),
#     "sv_mask": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
# }

input_dict = {
    "pf_features": np.random.rand(batch_size, 25, pfs).astype("float32"),
    "pf_vectors": np.random.rand(batch_size, 4, pfs).astype("float32"),
    "pf_mask": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
    "sv_features": np.random.rand(batch_size, 11, svs).astype("float32"),
    "sv_vectors": np.random.rand(batch_size, 4, svs).astype("float32"),
    "sv_mask": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
}

# input_dict = {
#     "pf_points__0": np.random.rand(batch_size, 2, pfs).astype("float32"),
#     "pf_features__1": np.random.rand(batch_size, 19, pfs).astype("float32"),
#     "pf_mask__2": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
#     "sv_points__3": np.random.rand(batch_size, 2, svs).astype("float32"),
#     "sv_features__4": np.random.rand(batch_size, 11, svs).astype("float32"),
#     "sv_mask__5": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
# }

In [29]:
# model_url = "triton+grpc://ailab01.fnal.gov:8001/particlenet_hww/1"
# model_url = "triton+grpc://prp-gpu-1.t2.ucsd.edu:8001/particlenet_hww/1"
# model_url = "triton+grpc://67.58.49.52:8001/ak8_MD_vminclv2ParT_manual_fixwrap/1"
model_url = "triton+grpc://67.58.49.52:8001/2023May30_ak8_MD_inclv8_part_2reg_manual/1"
triton_model = wrapped_triton(model_url)
for i in tqdm(range(1)):
    output = triton_model(input_dict)
print(output)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.34it/s]

[[ -9.643893   -11.078968    -6.423805   ...  -5.95106      0.4450065
    0.25934374]
 [-11.023443   -12.613604    -7.5247912  ...  -7.6548076    0.6109172
    0.35383254]
 [ -5.1560373   -7.3972573   -2.6038532  ...  -2.6032228    0.46883643
    0.20065397]
 [ -9.264318   -11.314613    -6.1544614  ...  -5.425825     0.66927564
    0.41242862]]





In [19]:
import onnx
import onnxruntime as ort


model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)

onnx_model = onnx.load(model_dir + "model.onnx")
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(model_dir + "model.onnx")
outputs = ort_sess.run(None, input_dict)[0]
print("ONNX outputs:", outputs)
print("Shape:", outputs.shape)
print("Softmax applied:", np.allclose(np.sum(outputs, axis=1), 1, atol=1e-5))

# Print Result
# predicted, actual = classes[outputs[0][0].argmax(0)], classes[y]
# print(f'Predicted: "{predicted}", Actual: "{actual}"')

ONNX outputs: [[ -9.643891   -11.078964    -6.423809   ...  -5.951058     0.44500652
    0.25934368]
 [-11.023443   -12.613605    -7.5247912  ...  -7.654809     0.61091703
    0.35383236]
 [ -5.1560345   -7.3972545   -2.6038537  ...  -2.60322      0.46883655
    0.2006541 ]
 [ -9.264317   -11.314613    -6.1544585  ...  -5.425826     0.6692767
    0.4124295 ]]
Shape: (4, 316)
Softmax applied: False


In [14]:
import onnx
import onnxruntime as ort


model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)

onnx_model = onnx.load(model_dir + "model.onnx")
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(model_dir + "model.onnx")
outputs = ort_sess.run(None, input_dict)
print("ONNX outputs:", outputs[0])
print("Shape:", outputs[0].shape)
print("Softmax applied:", np.allclose(np.sum(outputs[0], axis=1), 1, atol=1e-5))

# Print Result
# predicted, actual = classes[outputs[0][0].argmax(0)], classes[y]
# print(f'Predicted: "{predicted}", Actual: "{actual}"')

ONNX outputs: [[7.8597404e-06 1.8713886e-06 1.9673070e-04 ... 3.1563517e-04
  1.8921728e-01 1.5715510e-01]
 [5.0401678e-07 1.0276534e-07 1.6668266e-05 ... 1.4636063e-05
  5.6909341e-02 4.4008151e-02]
 [3.9485330e-04 4.1984258e-05 5.0679673e-03 ... 5.0711799e-03
  1.0946776e-01 8.3717473e-02]
 [3.6939423e-06 4.7539822e-07 8.2810249e-05 ... 1.7160318e-04
  7.6136842e-02 5.8890816e-02]]
Shape: (4, 316)
Softmax applied: True


In [16]:
print("ONNX cls output sums:", np.sum(outputs[0][:, :-2], axis=1))

ONNX cls outputs: [0.65362775 0.8990826  0.8068148  0.8649726 ]


2023May30 Model:

In [3]:
import torch
from ParticleTransformer import ParticleTransformerTagger

part_model = ParticleTransformerTagger(
    pf_input_dim=25,
    sv_input_dim=11,
    num_classes=314,  # one dim for regression
    # network configurations
    pair_input_dim=4,
    embed_dims=[128, 512, 128],
    pair_embed_dims=[64, 64, 64],
    num_heads=8,
    num_layers=8,
    num_cls_layers=2,
    block_params=None,
    cls_block_params={"dropout": 0, "attn_dropout": 0, "activation_dropout": 0},
    fc_params=[],
    activation="gelu",
    # misc
    trim=True,
    for_inference=True,
).cpu()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model_dir = (
    "models/model_2023May30/ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/"
)
part_model.load_state_dict(
    torch.load(model_dir + "net_best_epoch_state.pt", map_location=torch.device("cpu"))
)
_ = part_model.eval()

data_config = {
    "input_names": ["pf_features", "pf_vectors", "pf_mask", "sv_features", "sv_vectors", "sv_mask"],
    "input_shapes": {
        "pf_features": (-1, 25, pfs),
        "pf_vectors": (-1, 4, pfs),
        "pf_mask": (-1, 1, pfs),
        "sv_features": (-1, 11, svs),
        "sv_vectors": (-1, 4, svs),
        "sv_mask": (-1, 1, svs),
    },
}

model_info = {
    "input_names": list(data_config["input_names"]),
    "input_shapes": {k: ((1,) + s[1:]) for k, s in data_config["input_shapes"].items()},
    "output_names": ["softmax"],
    "dynamic_axes": {
        **{k: {0: "N", 2: "n_" + k.split("_")[0]} for k in data_config["input_names"]},
        **{"softmax": {0: "N"}},
    },
}

inputs = tuple(
    torch.ones(model_info["input_shapes"][k], dtype=torch.float32)
    for k in model_info["input_names"]
)
torch.onnx.export(
    part_model,
    inputs,
    model_dir + "ak8_MD_inclv8_part_2reg_manual.useamp.lite.gm5.ddp-bs768-lr6p75e-3/model.onnx",
    input_names=model_info["input_names"],
    output_names=model_info["output_names"],
    dynamic_axes=model_info.get("dynamic_axes", None),
    opset_version=11,
)

RuntimeError: Error(s) in loading state_dict for ParticleTransformerTagger:
	Missing key(s) in state_dict: "part.fc.0.weight", "part.fc.0.bias". 
	Unexpected key(s) in state_dict: "part.fc.1.weight", "part.fc.1.bias", "part.fc.0.0.weight", "part.fc.0.0.bias". 
	size mismatch for pf_embed.embed.1.weight: copying a param with shape torch.Size([64, 25]) from checkpoint, the shape in current model is torch.Size([128, 25]).
	size mismatch for pf_embed.embed.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for pf_embed.embed.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for pf_embed.embed.3.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for pf_embed.embed.4.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for pf_embed.embed.4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for pf_embed.embed.6.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for pf_embed.embed.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for pf_embed.embed.7.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for pf_embed.embed.7.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sv_embed.embed.1.weight: copying a param with shape torch.Size([64, 11]) from checkpoint, the shape in current model is torch.Size([128, 11]).
	size mismatch for sv_embed.embed.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sv_embed.embed.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sv_embed.embed.3.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for sv_embed.embed.4.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for sv_embed.embed.4.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sv_embed.embed.6.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sv_embed.embed.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for sv_embed.embed.7.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for sv_embed.embed.7.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_token: copying a param with shape torch.Size([1, 1, 64]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
	size mismatch for part.pair_embed.embed.1.weight: copying a param with shape torch.Size([32, 4, 1]) from checkpoint, the shape in current model is torch.Size([64, 4, 1]).
	size mismatch for part.pair_embed.embed.1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.2.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.2.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.2.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.4.weight: copying a param with shape torch.Size([32, 32, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1]).
	size mismatch for part.pair_embed.embed.4.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.5.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.5.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.5.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.5.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.7.weight: copying a param with shape torch.Size([32, 32, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1]).
	size mismatch for part.pair_embed.embed.7.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.8.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.8.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.8.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.8.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for part.pair_embed.embed.10.weight: copying a param with shape torch.Size([8, 32, 1]) from checkpoint, the shape in current model is torch.Size([8, 64, 1]).
	size mismatch for part.blocks.0.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.0.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.0.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.0.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.0.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.0.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.0.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.0.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.0.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.0.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.1.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.1.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.1.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.1.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.1.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.1.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.1.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.1.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.1.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.2.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.2.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.2.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.2.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.2.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.2.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.2.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.2.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.2.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.3.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.3.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.3.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.3.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.3.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.3.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.3.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.3.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.3.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.4.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.4.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.4.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.4.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.4.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.4.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.4.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.4.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.4.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.5.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.5.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.5.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.5.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.5.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.5.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.5.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.5.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.5.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.6.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.6.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.6.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.6.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.6.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.6.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.6.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.6.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.6.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.blocks.7.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.blocks.7.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.blocks.7.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.blocks.7.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.blocks.7.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.7.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.7.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.blocks.7.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.blocks.7.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.cls_blocks.0.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.cls_blocks.0.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.cls_blocks.0.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.0.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.cls_blocks.0.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.0.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.0.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.0.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.cls_blocks.0.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.w_resid: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.pre_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.pre_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.attn.in_proj_weight: copying a param with shape torch.Size([192, 64]) from checkpoint, the shape in current model is torch.Size([384, 128]).
	size mismatch for part.cls_blocks.1.attn.in_proj_bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for part.cls_blocks.1.attn.out_proj.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for part.cls_blocks.1.attn.out_proj.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.post_attn_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.post_attn_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.pre_fc_norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.pre_fc_norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.cls_blocks.1.fc1.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for part.cls_blocks.1.fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.1.post_fc_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.1.post_fc_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for part.cls_blocks.1.fc2.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for part.cls_blocks.1.fc2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.norm.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for part.norm.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).

In [None]:
in_tensors = [torch.Tensor(val) for key, val in input_dict.items()]
in_tensors[2] = in_tensors[2].bool()
in_tensors[5] = in_tensors[5].bool()

out = part_model(*in_tensors)
print(out)

In [None]:
output

In [None]:
out

Dec22 Model:

In [None]:
import torch
from ParticleTransformerHidden import ParticleTransformerTagger

part_model = ParticleTransformerTagger(
    pf_input_dim=25,
    sv_input_dim=11,
    num_classes=37 + 1,  # one dim for regression
    # network configurations
    pair_input_dim=4,
    embed_dims=[128, 512, 128],
    pair_embed_dims=[64, 64, 64],
    num_heads=8,
    num_layers=8,
    num_cls_layers=2,
    block_params=None,
    cls_block_params={"dropout": 0, "attn_dropout": 0, "activation_dropout": 0},
    fc_params=[],
    activation="gelu",
    # misc
    trim=True,
    for_inference=True,
).cpu()

In [None]:
part_model.load_state_dict(torch.load("net_best_epoch_state.pt", map_location=torch.device("cpu")))
_ = part_model.eval()

data_config = {
    "input_names": ["pf_features", "pf_vectors", "pf_mask", "sv_features", "sv_vectors", "sv_mask"],
    "input_shapes": {
        "pf_features": (-1, 25, pfs),
        "pf_vectors": (-1, 4, pfs),
        "pf_mask": (-1, 1, pfs),
        "sv_features": (-1, 11, svs),
        "sv_vectors": (-1, 4, svs),
        "sv_mask": (-1, 1, svs),
    },
}

model_info = {
    "input_names": list(data_config["input_names"]),
    "input_shapes": {k: ((1,) + s[1:]) for k, s in data_config["input_shapes"].items()},
    "output_names": ["softmax"],
    "dynamic_axes": {
        **{k: {0: "N", 2: "n_" + k.split("_")[0]} for k in data_config["input_names"]},
        **{"softmax": {0: "N"}},
    },
}

inputs = tuple(
    torch.ones(model_info["input_shapes"][k], dtype=torch.float32)
    for k in model_info["input_names"]
)
torch.onnx.export(
    part_model,
    inputs,
    "model.onnx",
    input_names=model_info["input_names"],
    output_names=model_info["output_names"],
    dynamic_axes=model_info.get("dynamic_axes", None),
    opset_version=11,
)

In [None]:
in_tensors = [torch.Tensor(val) for key, val in input_dict.items()]
in_tensors[2] = in_tensors[2].bool()
in_tensors[5] = in_tensors[5].bool()

out = part_model(*in_tensors)
print(out)

In [None]:
output