From 407a9057478d3deea0a9984af42162d21afa2bd2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 17 Dec 2021 14:59:46 +0100 Subject: [PATCH] Check TensorRT>=8.0.0 version (#6021) * Check TensorRT>=8.0.0 version * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- models/common.py | 5 +++-- utils/general.py | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index cfecb20d2141..4fd608f4b3e2 100644 --- a/models/common.py +++ b/models/common.py @@ -21,8 +21,8 @@ from torch.cuda import amp from utils.datasets import exif_transpose, letterbox -from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible, - non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh) +from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path, + make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box from utils.torch_utils import copy_attr, time_sync @@ -328,6 +328,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): elif engine: # TensorRT LOGGER.info(f'Loading {w} for TensorRT inference...') import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download + check_version(trt.__version__, '8.0.0', verbose=True) # version requirement Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) logger = trt.Logger(trt.Logger.INFO) with open(w, 'rb') as f, trt.Runtime(logger) as runtime: diff --git a/utils/general.py b/utils/general.py index 1da8a147510e..7ff397fb4caa 100755 --- a/utils/general.py +++ b/utils/general.py @@ -248,14 +248,16 @@ def check_python(minimum='3.6.2'): check_version(platform.python_version(), minimum, name='Python ', hard=True) -def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False): +def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False): # Check version vs. required version current, minimum = (pkg.parse_version(x) for x in (current, minimum)) result = (current == minimum) if pinned else (current >= minimum) # bool - if hard: # assert min requirements met - assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' - else: - return result + s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string + if hard: + assert result, s # assert min requirements met + if verbose and not result: + LOGGER.warning(s) + return result @try_except