Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ultralytics 8.1.46 add TensorRT 10 support #9516

Merged
merged 41 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
0995ecd
add support for TensorRT 10
Burhan-Q Apr 2, 2024
4a4c68f
add TensorRT10 compatibility
Burhan-Q Apr 2, 2024
079aa43
Auto-format by https://ultralytics.com/actions
UltralyticsAssistant Apr 2, 2024
6850472
Merge branch 'main' into trt10
glenn-jocher Apr 3, 2024
70c2f6b
Merge branch 'main' into trt10
glenn-jocher Apr 6, 2024
bc21f77
Merge branch 'main' into trt10
glenn-jocher Apr 6, 2024
02c4b99
Merge branch 'main' into trt10
glenn-jocher Apr 6, 2024
94ea376
Merge branch 'main' into trt10
glenn-jocher Apr 7, 2024
8418063
Update exporter.py
glenn-jocher Apr 7, 2024
32b2445
Updates
glenn-jocher Apr 7, 2024
1f553c2
Add support to exporting or inference with TensorRT 10.0.0b6, and fix…
ZouJiu1 Apr 7, 2024
d0a088d
Update __init__.py
glenn-jocher Apr 7, 2024
a090b75
Merge branch 'main' into trt10
glenn-jocher Apr 7, 2024
4c40158
Merge branch 'main' into trt10
glenn-jocher Apr 8, 2024
574a3a5
Removed changes from 1f553c2 due to errors, now working and fixes iss…
Burhan-Q Apr 8, 2024
0964835
Auto-format by https://ultralytics.com/actions
UltralyticsAssistant Apr 8, 2024
905fd05
ensure max shape dimension scales no smaller than 1
Burhan-Q Apr 8, 2024
139c2da
Merge branch 'trt10' of https://github.com/ultralytics/ultralytics in…
Burhan-Q Apr 8, 2024
c1c0046
Merge branch 'main' into trt10
glenn-jocher Apr 8, 2024
0a9a92f
Merge branch 'main' into trt10
glenn-jocher Apr 9, 2024
4e81132
Updates
glenn-jocher Apr 9, 2024
a38ba95
Merge remote-tracking branch 'origin/trt10' into trt10
glenn-jocher Apr 9, 2024
0cb3053
Merge branch 'main' into trt10
glenn-jocher Apr 9, 2024
2c03f83
Refactor constants out of for loop for speed
glenn-jocher Apr 9, 2024
d371949
Update exporter.py
glenn-jocher Apr 9, 2024
5a5e61f
refactored to use unified methods
Burhan-Q Apr 9, 2024
df6a3ff
Merge branch 'trt10' of https://github.com/ultralytics/ultralytics in…
Burhan-Q Apr 9, 2024
050b278
remove redundant line
Burhan-Q Apr 9, 2024
cd18b7b
remove self. as assigned automatically
glenn-jocher Apr 9, 2024
3017bad
Add TensorRT export test
glenn-jocher Apr 9, 2024
0e694cb
Merge remote-tracking branch 'origin/trt10' into trt10
glenn-jocher Apr 9, 2024
d263946
Update test_cuda.py
glenn-jocher Apr 9, 2024
e221ff5
refactored to include older interface for `tensorrt<8.6`
Burhan-Q Apr 9, 2024
3bab6b4
Merge branch 'trt10' of https://github.com/ultralytics/ultralytics in…
Burhan-Q Apr 9, 2024
529293f
refactor to include path for installs with `tensorrt<8.6`
Burhan-Q Apr 9, 2024
117a011
fix dynamic where self. is needed, align else, rearrange flow
Burhan-Q Apr 9, 2024
1f49b5d
correct for error with TensorRT 8.4.3.1 without method `get_tensor_sh…
Burhan-Q Apr 9, 2024
1ef7818
delete redundant trt10 definition
glenn-jocher Apr 10, 2024
9b96f5b
Align variable names in Exporter and Autobackend
glenn-jocher Apr 10, 2024
ba20133
Merge branch 'main' into trt10
glenn-jocher Apr 10, 2024
1fba2f3
move test to slow
glenn-jocher Apr 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def test_checks():
assert torch.cuda.is_available() == CUDA_IS_AVAILABLE
assert torch.cuda.device_count() == CUDA_DEVICE_COUNT

@pytest.mark.slow
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
def test_export_engine():
    """Test exporting the YOLO model to NVIDIA TensorRT format."""
    # Export the model to a serialized TensorRT engine file on GPU 0.
    engine_file = YOLO(MODEL).export(format="engine", device=0)
    # Reload the exported engine and run one inference to validate it works end-to-end.
    engine_model = YOLO(engine_file)
    engine_model(BUS, device=0)


@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
def test_train():
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.1.45"
__version__ = "8.1.46"

from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
Expand Down
29 changes: 18 additions & 11 deletions ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def export_coreml(self, prefix=colorstr("CoreML:")):
def export_engine(self, prefix=colorstr("TensorRT:")):
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
self.args.simplify = True
f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016

try:
Expand All @@ -666,12 +667,10 @@ def export_engine(self, prefix=colorstr("TensorRT:")):
if LINUX:
check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
import tensorrt as trt # noqa

check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0

self.args.simplify = True

LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
f = self.file.with_suffix(".engine") # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
Expand All @@ -680,7 +679,11 @@ def export_engine(self, prefix=colorstr("TensorRT:")):

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = int(self.args.workspace * (1 << 30))
workspace = int(self.args.workspace * (1 << 30))
if is_trt10:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
else: # TensorRT versions 7, 8
config.max_workspace_size = workspace
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
Expand All @@ -699,27 +702,31 @@ def export_engine(self, prefix=colorstr("TensorRT:")):
if shape[0] <= 1:
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
opt_shape = (max(1, shape[0] // 2), *shape[1:]) # optimal input shape
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)

LOGGER.info(
f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
)
if builder.platform_has_fast_fp16 and self.args.half:
half = builder.platform_has_fast_fp16 and self.args.half
LOGGER.info(f"{prefix} building FP{16 if half else 32} engine as {f}")
if half:
config.set_flag(trt.BuilderFlag.FP16)

# Free CUDA memory
del self.model
torch.cuda.empty_cache()

# Write file
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
build = builder.build_serialized_network if is_trt10 else builder.build_engine
with build(network, config) as engine, open(f, "wb") as t:
# Metadata
meta = json.dumps(self.metadata)
t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
t.write(meta.encode())
# Model
t.write(engine.serialize())
t.write(engine if is_trt10 else engine.serialize())

return f, None

Expand Down
71 changes: 51 additions & 20 deletions ultralytics/nn/autobackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,23 +234,47 @@ def __init__(
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
context = model.create_execution_context()

# Model context
try:
context = model.create_execution_context()
except Exception as e: # model is None
LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
raise e

bindings = OrderedDict()
output_names = []
fp16 = False # default updated below
dynamic = False
for i in range(model.num_bindings):
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
if dtype == np.float16:
fp16 = True
else: # output
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
is_trt10 = not hasattr(model, "num_bindings")
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
for i in num:
if is_trt10:
name = model.get_tensor_name(i)
dtype = trt.nptype(model.get_tensor_dtype(name))
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
if is_input:
if -1 in tuple(model.get_tensor_shape(name)):
dynamic = True
context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_tensor_shape(name))
else: # TensorRT < 10.0
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
is_input = model.binding_is_input(i)
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
Expand Down Expand Up @@ -463,13 +487,20 @@ def callback(request, userdata):

# TensorRT
elif self.engine:
if self.dynamic and im.shape != self.bindings["images"].shape:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
if self.dynamic or im.shape != self.bindings["images"].shape:
if self.is_trt10:
self.context.set_input_shape("images", im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
else:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))

s = self.bindings["images"].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs["images"] = int(im.data_ptr())
Expand Down