In [1]:
import tensorrt as trt

In [None]:
# Build a TensorRT engine from an ONNX model and serialize it to disk.
#
# Fails loudly (RuntimeError) if the ONNX file cannot be parsed or the engine
# cannot be built, instead of continuing with an empty/partial network or
# crashing later on a None engine.

path_onnx_model = "../../model.onnx"
path_engine = "../../model.trt"
# Builder scratch-memory cap in bytes. 1 << 50 (1 PiB) exceeds any real GPU, so it
# effectively imposes no extra restriction beyond what the device can provide.
# NOTE(review): lower this if the build must share the GPU with other work.
WORKSPACE_BYTES = 1 << 50

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, WORKSPACE_BYTES)
# Permit FP16 kernels in addition to FP32.
config.set_flag(trt.BuilderFlag.FP16)
# INT8 is intentionally not enabled: it would additionally require a calibrator
# or an ONNX model that already contains explicit Q/DQ nodes.
config.default_device_type = trt.DeviceType.GPU

# EXPLICIT_BATCH is required on TensorRT 8.x; on 10.x it is deprecated and ignored
# because every network is explicit-batch there.
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)

parser = trt.OnnxParser(network, trt_logger)
with open(path_onnx_model, "rb") as f:
    if not parser.parse(f.read()):
        print(f"ERROR: Failed to parse the ONNX file {path_onnx_model}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        # Without this, the build below would silently run on an empty or
        # partially populated network.
        raise RuntimeError(f"Failed to parse the ONNX file {path_onnx_model}")

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
# `tensor` rather than `input`, so the builtin input() is not shadowed in later cells.
for tensor in inputs:
    print(f"Model {tensor.name} shape: {tensor.shape} {tensor.dtype}")

# NOTE(review): the parsed network exposes many intermediate tensors as outputs
# (e.g. onnx::MatMul_*, input.*) besides logits/pred_boxes. If those are export
# leftovers, consider re-exporting the ONNX with only the needed outputs, since
# every network output must be materialized by the engine.
for tensor in outputs:
    print(f"Model {tensor.name} shape: {tensor.shape} {tensor.dtype}")

# build_serialized_network returns None on failure; details go to trt_logger.
engine_bytes = builder.build_serialized_network(network, config)
if engine_bytes is None:
    raise RuntimeError("TensorRT engine build failed; see the logger output above")
with open(path_engine, "wb") as f:
    # IHostMemory supports the buffer protocol, so no intermediate copy is needed.
    f.write(engine_bytes)

Model images shape: (1, 3, 480, 640) DataType.FLOAT
Model logits shape: (1, 300, 80) DataType.FLOAT
Model pred_boxes shape: (1, 300, 4) DataType.FLOAT
Model onnx::MatMul_2779 shape: (1, 300, 256) DataType.FLOAT
Model 2817 shape: (1, 3, 300, 256) DataType.FLOAT
Model onnx::Gather_2831 shape: (1, 3, 300, 80) DataType.FLOAT
Model onnx::Gather_2824 shape: (1, 3, 300, 4) DataType.FLOAT
Model input.332 shape: (1, 256, 60, 80) DataType.FLOAT
Model input.424 shape: (1, 256, 30, 40) DataType.FLOAT
Model input.516 shape: (1, 256, 15, 20) DataType.FLOAT
Model reference_points_unact shape: (1, 300, 4) DataType.FLOAT
Model 1568 shape: (1, 300, 80) DataType.FLOAT
Model 1548 shape: (1, 300, 4) DataType.FLOAT
Model onnx::ReduceMax_1506 shape: (1, 6300, 80) DataType.FLOAT
Model onnx::GatherElements_1518 shape: (1, 6300, 4) DataType.FLOAT
