
Export, detect and validation with TensorRT engine file #5699

Merged
merged 19 commits into from Nov 22, 2021

Conversation

@imyhxy (Contributor) commented Nov 18, 2021

Hi, there.

I have added support for exporting and detecting with a TensorRT plan file to YOLOv5. The requirement is that you install the tensorrt and pycuda Python packages.

Export :

python export.py --weights weights/yolov5s.pt --imgsz 384 640 --batch-size 1 --include engine --device 0  # export FP32 engine
python export.py --weights weights/yolov5s.pt --imgsz 384 640 --batch-size 1 --include engine --device 0 --half  # export FP16 engine
python export.py --weights weights/yolov5s.pt --imgsz 384 640 --batch-size 1 --include engine --device 0 --half --verbose  # print TensorRT building log

Output of export:

export: data=data/coco128.yaml, weights=weights/yolov5s.pt, imgsz=[384, 640], batch_size=1, device=0, half=False, inplace=False, train=False, optimize=False, int8=False, dynamic=False, simplify=False, verbose=False, opset=13, workspace=4, topk_per_class=100, topk_all=100, iou_thres=0.45, conf_thres=0.25, include=['engine']
YOLOv5 🚀 v6.0-96-g7bf6ad0 torch 1.10.0a0+git36449ea CUDA:0 (NVIDIA GeForce RTX 2060, 5935MiB)

Fusing layers...
/home/fkwong/miniconda3/envs/edge/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2156.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Model Summary: 213 layers, 7225885 parameters, 0 gradients

PyTorch: starting from weights/yolov5s.pt (14.7 MB)
ONNX: starting export with onnx 1.10.1...
yolov5/models/yolo.py:58: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
ONNX: export success, saved as weights/yolov5s.onnx (29.2 MB)
ONNX: run --dynamic ONNX model inference with: 'python detect.py --weights weights/yolov5s.onnx'

TensorRT: starting export with TensorRT 8.2.0.6...
[11/18/2021-14:57:08] [TRT] [I] [MemUsageChange] Init CUDA: CPU +312, GPU +0, now: CPU 2678, GPU 1841 (MiB)
[11/18/2021-14:57:09] [TRT] [W] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[11/18/2021-14:57:09] [TRT] [W] onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[11/18/2021-14:57:09] [TRT] [W] onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[11/18/2021-14:57:09] [TRT] [W] onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
TensorRT: Network Description:
TensorRT:     input "images" with shape (1, 3, 384, 640) and dtype DataType.FLOAT
TensorRT:     output "output" with shape (1, 15120, 85) and dtype DataType.FLOAT
TensorRT:     output "350" with shape (1, 3, 48, 80, 85) and dtype DataType.FLOAT
TensorRT:     output "416" with shape (1, 3, 24, 40, 85) and dtype DataType.FLOAT
TensorRT:     output "482" with shape (1, 3, 12, 20, 85) and dtype DataType.FLOAT
TensorRT: building FP32 engine in weights/yolov5s.trt
export.py:312: DeprecationWarning: Use build_serialized_network instead.
  with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
[11/18/2021-14:57:09] [TRT] [I] [MemUsageSnapshot] Builder begin: CPU 2824 MiB, GPU 1867 MiB
[11/18/2021-14:57:09] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 2827, GPU 1875 (MiB)
[11/18/2021-14:57:09] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2827, GPU 1883 (MiB)
[11/18/2021-14:57:09] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[11/18/2021-14:57:49] [TRT] [I] Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
[11/18/2021-14:58:40] [TRT] [I] [BlockAssignment] Algorithm Linear took 0.035196ms to assign 139 blocks to 139 nodes requiring 4444554755 bytes.
[11/18/2021-14:58:40] [TRT] [I] Total Activation Memory: 149587459
[11/18/2021-14:58:40] [TRT] [I] Detected 1 inputs and 7 output network tensors.
[11/18/2021-14:58:40] [TRT] [I] Total Host Persistent Memory: 139168
[11/18/2021-14:58:40] [TRT] [I] Total Device Persistent Memory: 16933888
[11/18/2021-14:58:40] [TRT] [I] Total Scratch Memory: 0
[11/18/2021-14:58:40] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 7 MiB, GPU 16 MiB
[11/18/2021-14:58:40] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 22.9321ms to assign 9 blocks to 138 nodes requiring 20582401 bytes.
[11/18/2021-14:58:40] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 2844, GPU 1933 (MiB)
[11/18/2021-14:58:40] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +10, now: CPU 2844, GPU 1943 (MiB)
[11/18/2021-14:58:41] [TRT] [I] [MemUsageSnapshot] Builder end: CPU 2843 MiB, GPU 1911 MiB
TensorRT: serializing engine to file: weights/yolov5s.trt

Export complete (96.73s)
Results saved to yolov5/weights
Visualize with https://netron.app

Detect:

python detect.py --weights weights/yolov5s.trt --source data/images/bus.jpg --view-img --imgsz 384 640

Output of detect:

detect: weights=['weights/yolov5s.trt'], source=data/images/bus.jpg, imgsz=[384, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=True, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=runs/detect, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False
YOLOv5 πŸš€ v6.0-101-gb555690 torch 1.10.0a0+git36449ea CUDA:0 (NVIDIA GeForce RTX 2060, 5935MiB)

Loading weights/yolov5s.trt for TensorRT inference...
[11/18/2021-21:30:49] [TRT] [I] [MemUsageChange] Init CUDA: CPU +312, GPU +0, now: CPU 417, GPU 718 (MiB)
[11/18/2021-21:30:49] [TRT] [I] Loaded engine size: 31 MiB
[11/18/2021-21:30:49] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine begin: CPU 448 MiB, GPU 718 MiB
[11/18/2021-21:30:50] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +497, GPU +214, now: CPU 955, GPU 964 (MiB)
[11/18/2021-21:30:51] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +168, GPU +202, now: CPU 1123, GPU 1166 (MiB)
[11/18/2021-21:30:51] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine end: CPU 1123 MiB, GPU 1148 MiB
[11/18/2021-21:30:51] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation begin: CPU 1092 MiB, GPU 1164 MiB
[11/18/2021-21:30:51] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 1092, GPU 1172 (MiB)
[11/18/2021-21:30:51] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 1092, GPU 1180 (MiB)
[11/18/2021-21:30:51] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation end: CPU 1093 MiB, GPU 1216 MiB
image 1/1 /home/user/git/yolov5/data/images/bus.jpg: 384x640 3 class0s, 1 class5, Done. (0.035s)
Speed: 1.6ms pre-process, 34.8ms inference, 16.5ms NMS per image at shape (1, 3, 384, 640)
Results saved to runs/detect/exp47

bus

Note: you have to manually specify the image size (both width and height) in the detection phase.

πŸ› οΈ PR Summary

Made with ❀️ by Ultralytics Actions

🌟 Summary

Added support for NVIDIA TensorRT inference in YOLOv5.

πŸ“Š Key Changes

  • Introduced TensorRT .engine file support in models/common.py, export.py, and val.py.
  • Added capability to export models to TensorRT format from .onnx using the export_engine() function in export.py.
  • Adjusted half-precision (FP16) checks to include TensorRT alongside PyTorch in detect.py and val.py.
  • Included verbose logging and workspace size configuration as new TensorRT export options in export.py.

🎯 Purpose & Impact

  • Purpose: To enable faster inference on NVIDIA GPUs by leveraging TensorRT optimizations.
  • Impact: Users with compatible NVIDIA hardware can now benefit from reduced latency and improved performance in real-time applications. This update also broadens YOLOv5's compatibility, catering to a wider audience using NVIDIA's ecosystem. πŸš€πŸ–₯️

@imyhxy closed this Nov 18, 2021
@imyhxy reopened this Nov 18, 2021
@imyhxy (Contributor Author) commented Nov 18, 2021

I just noticed that there are already some PRs discussing making val.py work with TensorRT engine files. I modified val.py a little to make this PR's work compatible with val.py.

Output of TensorRT validation:

val: data=data/coco.yaml, weights=['weights/yolov5s.trt'], batch_size=32, imgsz=640, conf_thres=0.001, iou_thres=0.6, task=val, device=0, single_cls=False, augment=False, verbose=False, save_txt=False, save_hybrid=False, save_conf=False, save_json=True, project=runs/val, name=exp, exist_ok=False, half=False, dnn=False
YOLOv5 πŸš€ v6.0-99-gc61bd93 torch 1.10.0a0+git36449ea CUDA:0 (NVIDIA GeForce RTX 2060, 5935MiB)

Loading weights/yolov5s.trt for TensorRT inference...
[11/18/2021-16:16:20] [TRT] [I] [MemUsageChange] Init CUDA: CPU +318, GPU +0, now: CPU 423, GPU 800 (MiB)
[11/18/2021-16:16:20] [TRT] [I] Loaded engine size: 36 MiB
[11/18/2021-16:16:20] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine begin: CPU 460 MiB, GPU 800 MiB
[11/18/2021-16:16:20] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +497, GPU +214, now: CPU 966, GPU 1050 (MiB)
[11/18/2021-16:16:20] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +169, GPU +204, now: CPU 1135, GPU 1254 (MiB)
[11/18/2021-16:16:20] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine end: CPU 1134 MiB, GPU 1236 MiB
[11/18/2021-16:16:20] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation begin: CPU 1098 MiB, GPU 1262 MiB
[11/18/2021-16:16:20] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 1098, GPU 1272 (MiB)
[11/18/2021-16:16:20] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 1098, GPU 1280 (MiB)
[11/18/2021-16:16:20] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation end: CPU 1098 MiB, GPU 1334 MiB
Forcing --batch-size 1 square inference shape(1,3,640,640) for non-PyTorch backends
val: Scanning '/datasets/11_mscoco/YOLO/val2017.cache' images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [00:00<?, ?it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [02:14<00:00, 37.18it/s]                                                                                                                                                                                    
                 all       5000      36335      0.658      0.505      0.552      0.358
Speed: 0.5ms pre-process, 10.2ms inference, 8.4ms NMS per image at shape (1, 3, 640, 640)

Evaluating pycocotools mAP... saving runs/val/exp10/yolov5s_predictions.json...
loading annotations into memory...
Done (t=0.48s)
creating index...
index created!
Loading and preparing results...
DONE (t=6.54s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=98.43s).
Accumulating evaluation results...
DONE (t=20.10s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.369
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.560
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.397
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.217
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.423
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.475
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.305
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.515
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.567
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.381
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.631
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.713

@imyhxy changed the title from "Export and detect with TensorRT engine file" to "Export, detect and validation with TensorRT engine file" Nov 18, 2021
@glenn-jocher (Member)

@imyhxy awesome, this looks really promising! Yes, there is another TRT PR, but it did not insert the inference code into DetectMultiBackend properly; this PR seems like it does. I will review today or tomorrow.

@glenn-jocher mentioned this pull request Nov 18, 2021
@imyhxy (Contributor Author) commented Nov 19, 2021

Hi, @glenn-jocher

I just looked into the mentioned pull request 5700 and updated my implementation. We now get rid of the pycuda package and use PyTorch for memory management. The number of memory copies is reduced and the inference speed is increased. Check the following logs for a detailed comparison.
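For reference, here is a minimal sketch of that idea (not the exact PR code; the engine path and shapes are placeholders): the output buffers are ordinary torch CUDA tensors, and TensorRT's execute_v2 only receives their raw device pointers.

```python
import numpy as np
import tensorrt as trt
import torch

# Sketch: TensorRT inference with PyTorch-managed device memory (no pycuda).
logger = trt.Logger(trt.Logger.INFO)
with open("weights/yolov5s.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# One torch CUDA tensor per binding, created in binding-index order;
# torch owns the memory and TensorRT only sees the device pointers.
bindings = {}
for i in range(engine.num_bindings):
    name = engine.get_binding_name(i)
    dtype = trt.nptype(engine.get_binding_dtype(i))
    shape = tuple(engine.get_binding_shape(i))
    bindings[name] = torch.from_numpy(np.empty(shape, dtype=dtype)).to("cuda")

im = torch.zeros(1, 3, 384, 640, device="cuda")  # preprocessed input batch for an FP32 engine
bindings["images"].copy_(im)                      # single device-to-device copy into the input binding
context.execute_v2([int(t.data_ptr()) for t in bindings.values()])
pred = bindings["output"]                         # detections stay on the GPU as a torch tensor
```

Since torch owns all the device buffers, no pycuda context management is needed and the outputs can flow straight into the existing NMS code.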

PyTorch FP16 model inference with batch size 1:

Model Summary: 213 layers, 7225885 parameters, 0 gradients
val: Scanning '/home/user/datasets/11_mscoco/YOLO/val2017.cache' images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [00:00<?, ?it/s]
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [03:15<00:00, 25.59it/s]                                                             
                 all       5000      36335      0.668      0.505      0.555      0.359
Speed: 0.3ms pre-process, 22.0ms inference, 2.3ms NMS per image at shape (1, 3, 640, 640)

TensorRT FP16 inference with batch size 1:

val: data=data/coco.yaml, weights=['weights/yolov5s_fp16.engine'], batch_size=1, imgsz=640, conf_thres=0.001, iou_thres=0.6, task=val, device=, single_cls=False, augment=False, verbose=False, save_txt=False, save_hybrid=False, save_conf=False, save_json=True, project=runs/val, name=exp, exist_ok=False, half=True, dnn=False
YOLOv5 πŸš€ v6.0-104-g038e141 torch 1.10.0a0+git36449ea CUDA:0 (NVIDIA GeForce RTX 2060, 5935MiB)

Loading weights/yolov5s_fp16.engine for TensorRT inference...
[11/19/2021-16:13:10] [TRT] [I] [MemUsageChange] Init CUDA: CPU +320, GPU +0, now: CPU 413, GPU 493 (MiB)
[11/19/2021-16:13:10] [TRT] [I] Loaded engine size: 17 MiB
[11/19/2021-16:13:10] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine begin: CPU 431 MiB, GPU 493 MiB
[11/19/2021-16:13:10] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +496, GPU +217, now: CPU 936, GPU 728 (MiB)
[11/19/2021-16:13:11] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +169, GPU +203, now: CPU 1105, GPU 931 (MiB)
[11/19/2021-16:13:11] [TRT] [I] [MemUsageSnapshot] deserializeCudaEngine end: CPU 1105 MiB, GPU 913 MiB
[11/19/2021-16:13:12] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation begin: CPU 2219 MiB, GPU 1433 MiB
[11/19/2021-16:13:12] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 2219, GPU 1443 (MiB)
[11/19/2021-16:13:12] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 2219, GPU 1451 (MiB)
[11/19/2021-16:13:12] [TRT] [I] [MemUsageSnapshot] ExecutionContext creation end: CPU 2219 MiB, GPU 1485 MiB
val: Scanning '/home/user/datasets/11_mscoco/YOLO/val2017.cache' images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [00:00<?, ?it/s]
Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [01:43<00:00, 48.47it/s]
all       5000      36335      0.647      0.511      0.552      0.358
Speed: 0.4ms pre-process, 4.5ms inference, 2.2ms NMS per image at shape (1, 3, 640, 640)

@Auth0rM0rgan

Hi @imyhxy, well done! I have tried your code to export to TensorRT and it works fine. The TensorRT model with half precision (FP16) runs almost 4x faster than native PyTorch with FP16. I don't have much experience with TensorRT. I wanted to ask: is it possible to export the model to TensorRT INT8 instead of FP16 and then load it with Python?

Thanks!

@imyhxy (Contributor Author) commented Nov 19, 2021

@Auth0rM0rgan yes, but INT8 quantization drops the mAP massively with the current YOLOv5 model, so I haven't implemented that yet.
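For context, an INT8 build would go through a builder flag plus a calibrator; the following is only a hypothetical sketch (this PR does not implement it, and MyEntropyCalibrator / calib_images are placeholders you would have to provide):

```python
import tensorrt as trt

# Hypothetical sketch only: INT8 needs a calibration dataset and an IInt8Calibrator implementation.
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()
if builder.platform_has_fast_int8:
    config.set_flag(trt.BuilderFlag.INT8)
    # MyEntropyCalibrator would subclass trt.IInt8EntropyCalibrator2 and feed preprocessed
    # calibration batches; calib_images is a placeholder for that dataset.
    config.int8_calibrator = MyEntropyCalibrator(calib_images)
```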

@glenn-jocher (Member)

@Auth0rM0rgan you see a 4x speedup with TRT export and inference compared with the base PyTorch GPU inference?

@glenn-jocher (Member)

@imyhxy awesome, thanks for the updates!

@Auth0rM0rgan commented Nov 19, 2021

@Auth0rM0rgan you see a 4x speedup with TRT export and inference compared with the base PyTorch GPU inference?

@glenn-jocher, Yes, I exported yolov5m-objects365 and it sped up almost 4x with the same image size and hyperparameters. It took ~0.008 s per frame with base PyTorch GPU and ~0.002 s with TensorRT, which is amazing!

BTW, my GPU is an RTX 3090.

@glenn-jocher (Member)

@Auth0rM0rgan wow, thanks for confirming!

@imyhxy is this PR ready to merge?

@imyhxy (Contributor Author) commented Nov 19, 2021

@glenn-jocher yes! I have done some tests and it works fine.

@glenn-jocher (Member)

Great! /rebase

@glenn-jocher (Member)

@imyhxy do you know how to install tensorrt in Colab? I tried running this PR but cannot import tensorrt, even after pip install tensorrt, restarting, etc.

Screenshot 2021-11-19 at 22 21 15

@glenn-jocher (Member)

/rebase

@imyhxy (Contributor Author) commented Nov 20, 2021

@glenn-jocher Hi, tensorrt is not published on PyPI 😂. It is downloaded through the NVIDIA developer program. So you first need to register an NVIDIA account and then search for TensorRT in the NVIDIA download center. Follow the instructions to download the *.tar.gz file and uncompress it; the tensorrt Python package is under the python folder. You also need to set LD_LIBRARY_PATH to <tensorrt_dir>/lib64. I'm not sure whether the CUDA and cuDNN libraries are required or not. I will do some experiments on Colab and get back to you later.

@imyhxy (Contributor Author) commented Nov 20, 2021

@glenn-jocher Hi, I wrote a script to install the TensorRT package on Colab. I am not sure whether it works, because my account can only get a Tesla K80 GPU, whose compute capability is too low for the TensorRT runtime. Maybe you can help me test whether the script works. Thanks. script

@glenn-jocher (Member)

@imyhxy thanks! I just requested access to the notebook.

@imyhxy (Contributor Author) commented Nov 21, 2021

@glenn-jocher Morning here 🌞 My mistake, I didn't set the share permission right. New link

@glenn-jocher (Member) commented Nov 21, 2021

@imyhxy thanks! I got the notebook working, but detect.py inference is very slow with yolov5s.engine (270ms). I think perhaps it's running on CPU? Is that possible? BTW I hosted the Colab TRT install file here to allow for a public notebook example:
https://ultralytics.com/assets/TensorRT-8.2.0.6.Linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz

Screenshot 2021-11-21 at 17 36 23

EDIT: Nevermind, I re-ran with a V100 Colab instance and speeds improved to 3 ms. Perhaps this was just an A100 issue.

@glenn-jocher (Member) commented Nov 21, 2021

@imyhxy is there a reason you read from the new ONNX buffer instead of reading from the ONNX file?

I see that there is a parse_from_file(self: tensorrt.tensorrt.OnnxParser, model: str) → bool method here:
https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/parsers/Onnx/pyOnnx.html

EDIT: If we read from the ONNX file we could simplify the PR by leaving the export_onnx() function the way it is currently.
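For reference, a sketch of what that file-based path could look like (assuming an explicit-batch network and an already-exported weights/yolov5s.onnx; not code from this PR):

```python
import tensorrt as trt

# Sketch: parse the exported ONNX file directly from disk instead of from an in-memory buffer.
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file("weights/yolov5s.onnx"):
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise RuntimeError("failed to parse ONNX file")
```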

@imyhxy (Contributor Author) commented Nov 22, 2021

@glenn-jocher Hi🌞

  1. The TensorRT plan file won't and can't be run on the CPU. TensorRT is an NVIDIA deep learning GPU deployment library and requires an NVIDIA GPU to run.

  2. I see some version-mismatch warnings in your screenshot, so maybe that is the reason for the slow inference. TensorRT has a bunch of preset tactics which implement different operations (Conv, Linear, ...) on different GPU architectures (Pascal, Turing, Ampere, ...). When building the TensorRT plan file, it benchmarks each operation of the network against its supported tactics and chooses the fastest one. The A100 is an Ampere-architecture GPU and the V100 is Volta architecture, so the cuDNN library (one of the tactic sources for TensorRT) may be too old to have optimized tactics for the Ampere architecture (not confirmed yet). I will update the shared Colab to install the newest cuDNN library today.

  3. The TensorRT plan file is designed to run on the same machine that generated it. If someone wants to use the plan file on another machine, at least make sure the GPU and the CUDA and cuDNN versions are the same (there are other affecting factors: CPU, memory, ...).

  4. It's totally fine to read the ONNX file from disk; the reason I didn't do that is just to keep the ONNX and TensorRT exports independent, plus slightly faster saving and loading.

  5. build_engine and build_serialized_network are not equivalent: build_serialized_network returns a memory block holding an already-serialized engine, so you can't call the serialize() method on it. I didn't use the new API because I hadn't confirmed whether it works on TensorRT 7. I will check it today and get back to you.

EDIT: I have checked that build_serialized_network exists only on TensorRT 8, so I am not sure whether we should be compatible with TensorRT 8 only. If you want to remove the warning, just change build_engine to build_serialized_network and remove the following serialize() call (see the sketch below).

EDIT: Updated script. By the way, I am not sure whether hosting the TensorRT and cuDNN packages violates the license or not.
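To illustrate the difference between the two build APIs, a sketch (network, config and the output path f are assumed to exist as in export.py):

```python
# TensorRT 7 style: build_engine returns an ICudaEngine, which is then serialized explicitly.
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    t.write(engine.serialize())

# TensorRT 8 style: build_serialized_network already returns an IHostMemory blob,
# so there is no engine object and nothing to call serialize() on.
plan = builder.build_serialized_network(network, config)
with open(f, 'wb') as t:
    t.write(plan)
```

Either way, the bytes written to disk are a serialized engine; the difference is only which object the Python API hands back.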

@glenn-jocher (Member)

/rebase

@glenn-jocher (Member)

@imyhxy thanks, got it on all points! If you can confirm license problems then I should remove the hosting.

Everything else looks good, the only thing I noticed is the different handling in detect.py and val.py. Mainly this affects --half ops. In detect.py --half only applies to pt files, but in val.py we have pt or engine. Do the TRT models accept FP16 input images? Do the models need to be exported as FP16 models in this case? I'll run some tests on the PR and update here.

@imyhxy (Contributor Author) commented Nov 22, 2021

@glenn-jocher Hi, --half can be applied to both detect.py and val.py, and currently they both respect the --half option:

half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA

And the TensorRT plan file does accept, and only accepts, FP16 images when it was exported by the export.py script with the --half option. Otherwise the memory won't be aligned and the TensorRT runtime will raise an error, so the export stage and the detect/val stage should be consistent with the --half option.
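One way to keep the two stages consistent would be to derive the expected dtype from the engine itself instead of trusting the CLI flag; a minimal sketch (engine is an already-deserialized ICudaEngine and the image input is assumed to be binding 0):

```python
import tensorrt as trt

# Sketch: infer FP16 vs FP32 from the engine's input binding instead of the --half flag.
fp16 = engine.get_binding_dtype(0) == trt.DataType.HALF  # assumes 'images' is binding 0
im = im.half() if fp16 else im.float()  # input dtype must match how the engine was exported
```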

For the license part, I am not a native English speaker, which makes it hard for me to understand the license document. According to the following section, it seems we can't just redistribute the original files. FYI, I also attach the license documents here.

1.2. Distribution Requirements
These are the distribution requirements for you to exercise the distribution grant:
  1. Your application must have material additional functionality, beyond the included portions of the SDK.
  2. The distributable portions of the SDK shall only be accessed by your application.
  3. The following notice shall be included in modifications and derivative works of sample source code distributed: β€œThis software contains source code provided by NVIDIA Corporation.”
  4. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
  5. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA’s intellectual property rights. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.
  6. You agree to notify NVIDIA in writing of any known or suspected distribution or use of the SDK not in compliance with the requirements of this Agreement, and to enforce the terms of your agreements with respect to distributed SDK.

cuDNN-SLA.pdf
TensorRT-SLA.pdf

@glenn-jocher (Member)

@imyhxy I'm not a native English speaker either (it's my second language after Spanish), but I'm pretty good at it and I don't really understand the language there either! Probably best just to remove the hosting to stay safe; I'll update the Colab notebook Appendix section I made with the official URL.

You are right about --half; I didn't notice your detect.py updates!

@glenn-jocher merged commit 7a39803 into ultralytics:master Nov 22, 2021
@glenn-jocher (Member)

@imyhxy PR is merged. Thank you for your contributions to YOLOv5 πŸš€ and Vision AI ⭐

@glenn-jocher (Member)

@imyhxy BTW I took your notebook and squeezed it into a single cell at the end of the Appendix section. All of YOLOv5 has only one notebook currently for everything: https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb

Screenshot 2021-11-22 at 14 59 38

@guyiyifeurach

@imyhxy thanks! I got the notebook working, but detect.py inference is very slow with yolov5s.engine (270ms). I think perhaps it's running on CPU? Is that possible? BTW I hosted the Colab TRT install file here to allow for a public notebook example: https://ultralytics.com/assets/TensorRT-8.2.0.6.Linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz

Screenshot 2021-11-21 at 17 36 23

EDIT: Nevermind, I re-ran with a V100 Colab instance and speeds improved to 3 ms. Perhaps this was just an A100 issue.
@glenn-jocher
Hi, your inference times are all right. The reason is that when inferencing with an engine file, the first image's inference time is very large, while subsequent images are normal. You can check the per-image details in the output: bus.jpg takes 538 ms while zidane.jpg costs 2 ms, and the ~269.9 ms you mentioned is the average of the two images' inference times.
BTW, another possible reason is whether the GPU is running at its maximum frequency. I run the model on a Jetson Nano, and the inference time is not stable because the Nano's power mode is not MAXN and the GPU frequency is not at its maximum. Perhaps this will not happen on a GPU server, but I suspect it may be a factor. Anyone running the model on the Jetson series should pay attention to this point. It's my suggestion.
Thank you!

@wanghr323

Hi, I was doing QAT using TensorRT's tools, and now after exporting the ONNX model I need the extra nvinfer_plugin.so. How do I enable it?

@imyhxy (Contributor Author) commented Jan 18, 2022

@wanghr323 It seems you have not set up the TensorRT environment properly. Make sure your <TensorRT>/lib is included in the LD_LIBRARY_PATH environment variable.
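A quick sanity check could look like this (just a sketch, not code from the repo):

```python
import os

# The tensorrt Python bindings load NVIDIA shared libraries at import time, so the lib
# directory (and any plugin libraries such as libnvinfer_plugin.so) must be on LD_LIBRARY_PATH.
print(os.environ.get("LD_LIBRARY_PATH", "<not set>"))

import tensorrt as trt  # succeeds only if the shared libraries were found
print(trt.__version__)
```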

@passerbythesun (Contributor) commented Feb 17, 2022

@imyhxy When converting a PyTorch model to a TensorRT engine, the --half option mainly takes effect via the two lines below:

  1. im, model = im.half(), model.half() # to FP16
  2. config.set_flag(trt.BuilderFlag.FP16)

With line 1, the generated TensorRT network will have input/output bindings as follows:

0 INPUT kHALF images 3x640x640
1 OUTPUT kHALF 350 3x80x80x85
2 OUTPUT kHALF 418 3x40x40x85
3 OUTPUT kHALF 486 3x20x20x85
4 OUTPUT kFLOAT output 25200x85

Without line 1, the result is:

0 INPUT kFLOAT images 3x640x640
1 OUTPUT kFLOAT 350 3x80x80x85
2 OUTPUT kFLOAT 416 3x40x40x85
3 OUTPUT kFLOAT 482 3x20x20x85
4 OUTPUT kFLOAT output 25200x85

I noticed that model.half() and im.half() affect the ONNX export, which results in a difference in layer precision in the TensorRT network definition. But after building the engine with the FP16 flag, what's the difference between the generated engine files?

I don't know what exactly TensorRT does when the BuilderFlag is set to FP16.

Thanks!

@imyhxy (Contributor Author) commented Feb 17, 2022

@passerbythesun Hi, there.

Line 1 only affects the dtype of the input and output nodes of the engine. As you can see, the engine requires and outputs half-precision tensors (except the 'output' node, which is computed by some ops that don't support half precision) when line 1 is enabled, but requires and outputs single-precision tensors when it is disabled.

When line 2 is involved, the layers of the engine always run in half precision (except the ops that don't support half precision).

You can add the --verbose option to get a fine-grained engine build log which prints out the precision of each layer.
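You can also dump the binding dtypes of a built engine to see the I/O effect directly; a small sketch (the engine path is a placeholder):

```python
import tensorrt as trt

# Sketch: the FP16 builder flag controls layer precision inside the engine, while the I/O
# binding dtypes are fixed by the traced graph (i.e. by whether im/model were .half()'d).
logger = trt.Logger(trt.Logger.INFO)
with open("weights/yolov5s.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
for i in range(engine.num_bindings):
    kind = "INPUT " if engine.binding_is_input(i) else "OUTPUT"
    print(kind, engine.get_binding_name(i), engine.get_binding_dtype(i), engine.get_binding_shape(i))
```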

@ingbeeedd

@imyhxy

The TensorRT export does not depend on the cuDNN version, is that right?

@imyhxy (Contributor Author) commented Mar 29, 2022

@ingbeeedd
No, it depends on the version of TensorRT, because some ops are not supported by legacy TensorRT libraries. We have tested TensorRT >= 7.1. As for cuDNN, you need to install a library compatible with your TensorRT library; you can refer to the NVIDIA website.

BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this pull request Aug 26, 2022
…5699)

* Export and detect with TensorRT engine file

* Resolve `isort`

* Make validation works with TensorRT engine

* feat: update export docstring

* feat: change suffix from *.trt to *.engine

* feat: get rid of pycuda

* feat: make compatiable with val.py

* feat: support detect with fp16 engine

* Add Lite to Edge TPU string

* Remove *.trt comment

* Revert to standard success logger.info string

* Fix Deprecation Warning

```
export.py:310: DeprecationWarning: Use build_serialized_network instead.
  with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
```

* Revert deprecation warning fix

@imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here.

* Update export.py

* Update export.py

* Update common.py

* export onnx to file before building TensorRT engine file

* feat: triger ONNX export failed early

* feat: load ONNX model from file

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
@vivekdevre

@imyhxy can you implement similar mAP calculation code for .trt models of "yolov7"?


Successfully merging this pull request may close these issues.

tensorrt 7.2.3.4 gets some wrong boxes with high score, but tensorrt 8.2 is correct