In [1]:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
from loguru import logger

import torch
from torch import nn

from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module


def make_parser():
    parser = argparse.ArgumentParser("YOLOX onnx deploy")
    parser.add_argument(
        "--output-name", type=str, default="yolox.onnx", help="output name of models"
    )
    parser.add_argument(
        "--input", default="images", type=str, help="input node name of onnx model"
    )
    parser.add_argument(
        "--output", default="output", type=str, help="output node name of onnx model"
    )
    parser.add_argument(
        "-o", "--opset", default=11, type=int, help="onnx opset version"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--dynamic",
        action="store_true",
        help="whether the input shape should be dynamic or not",
    )
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--decode_in_inference", action="store_true", help="decode in inference or not"
    )

    return parser


args = make_parser().parse_args("")
args.exp_file = "./exps/yolov/yolov_l.py"
args.name = "yolov_l"
args.ckpt = "./yolov_l.pth"
args.output_name = "yolov_l.onnx"


logger.info("args value: {}".format(args))
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)

if not args.experiment_name:
    args.experiment_name = exp.exp_name

model = exp.get_model()
ckpt_file = args.ckpt

# load the model state dict
ckpt = torch.load(ckpt_file, map_location="cuda")

model.eval()
if "model" in ckpt:
    ckpt = ckpt["model"]
model.load_state_dict(ckpt)
model = replace_module(model, nn.SiLU, SiLU)
model.head.decode_in_inference = args.decode_in_inference
model.to("cuda").eval()
logger.info("loading checkpoint done.")

# For export body
dummy_input = torch.randn(5, 3, exp.test_size[0], exp.test_size[1]).cuda()

# For export head
# dummy_input = {
#     "pred_result": torch.randn(2, 30, 37).cuda(),
#     "pred_idx": torch.randn(2, 30).cuda(),
# }

# input_names = ["pred_result", "pred_idx"]
# output_names = ["output"]
output_names=["decode_res", "before_nms_features", "before_nms_regf", "outputs_decode"]
torch.onnx.export(
    model,
    dummy_input,
    args.output_name,
    input_names=["images"],
    output_names=output_names,
    dynamic_axes={"images": {0: "batch"}, 
                  "decode_res": {0: "batch"},
                  "before_nms_features": {0: "batch"},
                  "before_nms_regf": {0: "batch"},
                  "outputs_decode": {0: "batch"},},
    verbose=True,
    do_constant_folding=True,
    export_params=True,
    # opset_version=args.opset,
)
logger.info("generated onnx model named {}".format(args.output_name))

# Load the ONNX model
model = onnx.load(args.output_name)

# Get the model's graph
graph = model.graph

# Get the list of outputs
outputs = graph.output

# Remove the desired output
# Replace 'output_name' with the name of the output you want to remove
del_idx = []
for i, output in enumerate(outputs):
    if output.name not in output_names:
        del_idx.append(i)
        logger.info("delete output {}".format(output.name))
        
del_idx.reverse()
for i in del_idx:
    del outputs[i]

# Save the modified model
onnx.save(model, args.output_name)

import onnx
from onnxsim import simplify

# use onnx-simplifier to reduce reduent model.
onnx_model = onnx.load(args.output_name)
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, args.output_name)
logger.info("generated simplified onnx model named {}".format(args.output_name))


[32m2023-12-08 14:12:54.597[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m68[0m - [1margs value: Namespace(output_name='yolov_l.onnx', input='images', output='output', opset=11, batch_size=1, dynamic=False, no_onnxsim=False, exp_file='./exps/yolov/yolov_l.py', experiment_name=None, name='yolov_l', ckpt='./yolov_l.pth', opts=[], decode_in_inference=False)[0m
[32m2023-12-08 14:12:56.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1mloading checkpoint done.[0m
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  output = [None for _ in range(len(prediction))]
  output_index = [None for _ in range(len(prediction))]
  for i, image_pred in enumerate(prediction):
  if not image_pred.size(0):


: 

In [88]:
import torch
import torch.nn as nn
import torchvision

class WoClass(nn.Module):
    def __init__(self, num_classes, Prenum, topK, nms_thre) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.Prenum = Prenum
        self.topK = topK
        self.nms_thre = nms_thre
        
    def forward(self, image_pred):
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5 : 5 + self.num_classes], 1, keepdim=True
        )

        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (
                image_pred[:, :5],
                class_conf,
                class_pred.float(),
                image_pred[:, 5 : 5 + self.num_classes],
            ),
            1,
        )

        conf_score = image_pred[:, 4]
        top_pre = torch.topk(conf_score, k=self.Prenum)
        sort_idx = top_pre.indices[: self.Prenum]
        detections_temp = detections[sort_idx, :]
        nms_out_index = torchvision.ops.batched_nms(
            detections_temp[:, :4],
            detections_temp[:, 4] * detections_temp[:, 5],
            detections_temp[:, 6],
            self.nms_thre,
        )

        topk_idx = sort_idx[nms_out_index[: self.topK]]
        return detections[topk_idx, :], topk_idx
    
woClass = WoClass(num_classes=30, 
                #   Prenum=750,
                  Prenum=256,
                  topK=30,
                  nms_thre=0.75)
dummy_input = torch.randn(5376, 35)
woClass.eval()
woClass(dummy_input)

input_names = ["image_pred"]
output_names = ["output", "output_index"]
dynamic_axes = {"image_pred": {0: "batch"}, "output": {0: "batch"}, "output_index": {0: "batch"}}

dummy_input = torch.randn(5376, 35)

# torch.onnx.export(
#     woClass,
#     dummy_input,
#     "WoClass.onnx",
#     input_names=input_names,
#     output_names=output_names,
#     dynamic_axes=dynamic_axes,
#     verbose=True,
#     do_constant_folding=True,
#     export_params=True,
#     # opset_version=args.opset,
# )

                  

In [89]:
def postpro_woclass(
        prediction, 
        # num_classes, 
        # nms_thre=0.75, 
        # topK=75, 
        # features=None,
    ):
        # Prenum = 750
        box_corner = prediction.new(prediction.shape)
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
        prediction[:, :, :4] = box_corner[:, :, :4]
        outputs = [None for _ in range(len(prediction))]
        output_indexs = [None for _ in range(len(prediction))]

        for i, image_pred in enumerate(prediction):
            if not image_pred.size(0):
                continue
            output, output_index = woClass(image_pred)
            outputs[i] = output
            output_indexs[i] = output_index
        
        outputs = torch.stack(outputs)
        output_indexs = torch.stack(output_indexs)
        return outputs, output_indexs
    
def find_feature_score(
    features,
    idxs,
    reg_features,
    predictions=None,
):
    features_cls = []
    features_reg = []
    cls_scores = []
    fg_scores = []
    simN = 30
    for i, feature in enumerate(features):
        features_cls.append(feature[idxs[i][: simN]])
        features_reg.append(reg_features[i, idxs[i][: simN]])
        cls_scores.append(predictions[i][: simN, 5])
        fg_scores.append(predictions[i][: simN, 4])
    features_cls = torch.cat(features_cls)
    features_reg = torch.cat(features_reg)
    cls_scores = torch.cat(cls_scores)
    fg_scores = torch.cat(fg_scores)
    return features_cls, features_reg, cls_scores, fg_scores

In [93]:
decode_res = torch.randn(5, 256, 35)
pred_result, pred_idx = postpro_woclass(decode_res)

In [91]:
pred_result.shape, pred_idx.shape

(torch.Size([5, 30, 37]), torch.Size([5, 30]))

In [13]:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
from loguru import logger

import torch
from torch import nn

from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module


def make_parser():
    parser = argparse.ArgumentParser("YOLOX onnx deploy")
    parser.add_argument(
        "--output-name", type=str, default="yolox.onnx", help="output name of models"
    )
    parser.add_argument(
        "--input", default="images", type=str, help="input node name of onnx model"
    )
    parser.add_argument(
        "--output", default="output", type=str, help="output node name of onnx model"
    )
    parser.add_argument(
        "-o", "--opset", default=11, type=int, help="onnx opset version"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--dynamic",
        action="store_true",
        help="whether the input shape should be dynamic or not",
    )
    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--decode_in_inference", action="store_true", help="decode in inference or not"
    )

    return parser


args = make_parser().parse_args("")
args.exp_file = "./exps/yolov/yolov_l.py"
args.name = "yolov_l"
args.ckpt = "./yolov_l.pth"
args.output_name = "yolov_l.onnx"


logger.info("args value: {}".format(args))
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)

if not args.experiment_name:
    args.experiment_name = exp.exp_name

model = exp.get_model()
ckpt_file = args.ckpt

# load the model state dict
ckpt = torch.load(ckpt_file, map_location="cuda")

model.eval()
if "model" in ckpt:
    ckpt = ckpt["model"]
model.load_state_dict(ckpt)
model = replace_module(model, nn.SiLU, SiLU)
model.head.decode_in_inference = args.decode_in_inference
model.to("cuda")
logger.info("loading checkpoint done.")

# For export body
# dummy_input = torch.randn(3, 3, exp.test_size[0], exp.test_size[1]).cuda()

# For export head
dummy_input = {
    "pred_result": torch.randn(5, 30, 37).cuda(),
    "pred_idx": torch.randn(5, 30).cuda(),
    "before_nms_features": torch.randn(5, 256, 64, 64).cuda(),
    "before_nms_regf": torch.randn(5, 256, 32, 32).cuda(),
    "outputs_decode": torch.randn(5, 256, 16, 16).cuda(),
}

input_names = ["pred_result", 
               "pred_idx",
               "before_nms_features",
               "before_nms_regf",
               "outputs_decode",]
output_names = ["pred_result", "fc_output"]
dynamic_axes = {"pred_result": {0: "batch"}, "fc_output": {0: "batch"}}
torch.onnx.export(
    model.head,
    dummy_input,
    "yolo_l_head.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    verbose=True,
    do_constant_folding=True,
    export_params=True,
    # opset_version=args.opset,
)
logger.info("generated onnx model named {}".format("yolo_l_head.onnx"))


import onnx
from onnxsim import simplify

# use onnx-simplifier to reduce reduent model.
onnx_model = onnx.load("yolo_l_head.onnx")
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "yolo_l_head.onnx")
logger.info("generated simplified onnx model named {}".format("yolo_l_head.onnx"))


[32m2023-12-08 14:48:12.977[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m68[0m - [1margs value: Namespace(output_name='yolov_l.onnx', input='images', output='output', opset=11, batch_size=1, dynamic=False, no_onnxsim=False, exp_file='./exps/yolov/yolov_l.py', experiment_name=None, name='yolov_l', ckpt='./yolov_l.pth', opts=[], decode_in_inference=False)[0m
[32m2023-12-08 14:48:14.135[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1mloading checkpoint done.[0m
  [x.flatten(start_dim=2) for x in before_nms_features], dim=2
  [x.flatten(start_dim=2) for x in before_nms_regf], dim=2


IndexError: tensors used as indices must be long, int, byte or bool tensors

In [92]:
before_nms_features = torch.randn(5, 256, 64, 64)
before_nms_regf = torch.randn(5, 256, 32, 32)
cls_feat_flatten = torch.cat(
    [x.flatten(start_dim=2) for x in before_nms_features], dim=2
).permute(
    0, 2, 1
)  # [b,features,channels]
reg_feat_flatten = torch.cat(
    [x.flatten(start_dim=2) for x in before_nms_regf], dim=2
).permute(0, 2, 1)
features_cls, features_reg, cls_scores, fg_scores = find_feature_score(
            cls_feat_flatten, pred_idx, reg_feat_flatten, pred_result
        )
features_reg = features_reg.unsqueeze(0)
features_cls = features_cls.unsqueeze(0)
cls_scores = cls_scores.to(cls_feat_flatten.dtype)
fg_scores = fg_scores.to(cls_feat_flatten.dtype)


IndexError: index 227 is out of bounds for dimension 0 with size 160

In [31]:
cls_feat_flatten.shape, pred_idx.shape, reg_feat_flatten.shape, pred_result.shape

(torch.Size([256, 320, 64]),
 torch.Size([5, 30]),
 torch.Size([256, 160, 32]),
 torch.Size([5, 30, 37]))

In [62]:
import onnxruntime as ort
import numpy as np
import cv2
ort_session = ort.InferenceSession("./yolov_l_body.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

img = cv2.imread("./assets/dog.jpg")
img = cv2.resize(img, (512, 512))
imgs = [img, img]
imgs = np.stack(imgs, axis=0)
imgs = imgs.transpose(0, 3, 1, 2)
imgs = imgs.astype(np.float32)
imgs /= 255.0
imgs = np.ascontiguousarray(imgs)

ort_inputs = {input_name: imgs}
ort_outs = ort_session.run(None, ort_inputs)


In [78]:
decode_res, before_nms_features, before_nms_regf, outputs_decode, _ = ort_outs
decode_res = torch.from_numpy(decode_res)
before_nms_features = torch.from_numpy(before_nms_features)
before_nms_regf = torch.from_numpy(before_nms_regf)
outputs_decode = torch.from_numpy(outputs_decode)


In [79]:
pred_result, pred_idx = postpro_woclass(decode_res)
cls_feat_flatten = torch.cat(
    [x.flatten(start_dim=2) for x in before_nms_features], dim=2
).permute(
    0, 2, 1
)  # [b,features,channels]
reg_feat_flatten = torch.cat(
    [x.flatten(start_dim=2) for x in before_nms_regf], dim=2
).permute(0, 2, 1)
features_cls, features_reg, cls_scores, fg_scores = find_feature_score(
            cls_feat_flatten, pred_idx, reg_feat_flatten, pred_result
        )
features_reg = features_reg.unsqueeze(0)
features_cls = features_cls.unsqueeze(0)
cls_scores = cls_scores.to(cls_feat_flatten.dtype)
fg_scores = fg_scores.to(cls_feat_flatten.dtype)

IndexError: index 5273 is out of bounds for dimension 0 with size 128

In [80]:
cls_feat_flatten.shape, reg_feat_flatten.shape

(torch.Size([256, 128, 64]), torch.Size([256, 64, 32]))