ultralytics 8.1.25 OpenVINO LATENCY and THROUGHPUT modes #8058

Merged: 36 commits merged into main from ov-throughput-mode on Mar 6, 2024

Commits (36) · the file changes below reflect all commits
318a343
Enable OpenVINO models throughput mode
glenn-jocher Feb 6, 2024
92e13e4
Auto-format by https://ultralytics.com/actions
UltralyticsAssistant Feb 6, 2024
1275028
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 7, 2024
80130bc
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 9, 2024
3710e79
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 12, 2024
82e7d59
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 12, 2024
53a94bb
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 13, 2024
5eb2778
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 17, 2024
f76ff04
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 20, 2024
8a8ec7d
Add throughput mode code
glenn-jocher Feb 20, 2024
5f52ed3
Add throughput mode code
glenn-jocher Feb 20, 2024
5477165
Add compile_model config arg
glenn-jocher Feb 20, 2024
7240c1c
Update dependency from 2023.0 to 2023.3
glenn-jocher Feb 20, 2024
ecc7d42
Update dependency from 2023.0 to 2023.3
glenn-jocher Feb 20, 2024
6b67d1c
Debug
glenn-jocher Feb 20, 2024
8bbd760
Simplify batch dim handling
glenn-jocher Feb 20, 2024
0ca45a3
Cleanup
glenn-jocher Feb 20, 2024
46b7263
Cleanup
glenn-jocher Feb 20, 2024
73e09f4
Cleanup
glenn-jocher Feb 20, 2024
48cd8f2
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 20, 2024
e9bae3d
Merge branch 'main' into ov-throughput-mode
glenn-jocher Feb 21, 2024
ae898e1
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 5, 2024
11c1b3c
Update autobackend.py
glenn-jocher Mar 5, 2024
5919ac5
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 5, 2024
5f26239
Remove mo import
glenn-jocher Mar 5, 2024
d867e38
Fix ov imports
glenn-jocher Mar 5, 2024
d0c095e
Update inference mode logic
glenn-jocher Mar 5, 2024
f9f8b1e
Add userdata input
glenn-jocher Mar 5, 2024
21778ad
Update ultralytics/nn/autobackend.py
glenn-jocher Mar 5, 2024
1be6b7c
Cleanup autobackend comments
glenn-jocher Mar 5, 2024
f19995e
Correct THROUGHPUT mode sort order
glenn-jocher Mar 5, 2024
0ac74bd
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 5, 2024
8260d01
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 5, 2024
aeb610e
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 5, 2024
9c3f16c
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 6, 2024
f1b4685
Merge branch 'main' into ov-throughput-mode
glenn-jocher Mar 6, 2024
6 changes: 3 additions & 3 deletions ultralytics/engine/exporter.py
@@ -411,8 +411,8 @@ def export_onnx(self, prefix=colorstr("ONNX:")):
@try_export
def export_openvino(self, prefix=colorstr("OpenVINO:")):
"""YOLOv8 OpenVINO export."""
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
import openvino as ov # noqa
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino/
import openvino as ov

LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
@@ -433,7 +433,7 @@ def serialize(ov_model, file):
if self.model.task != "classify":
ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half)
yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml

if self.args.int8:
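For context, the updated export path can be exercised end-to-end through the usual ultralytics API. A minimal sketch, assuming the stock yolov8n weights and a sample image (both are placeholders, not part of this diff):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # pretrained PyTorch weights
model.export(format="openvino", half=False)  # writes yolov8n_openvino_model/ with the model XML/BIN plus metadata.yaml

ov_model = YOLO("yolov8n_openvino_model/")  # AutoBackend detects the *_openvino_model directory and takes the OpenVINO branch
results = ov_model("https://ultralytics.com/images/bus.jpg")  # inference runs on the compiled OpenVINO model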
142 changes: 109 additions & 33 deletions ultralytics/nn/autobackend.py
@@ -135,8 +135,8 @@ def __init__(
if not (pt or triton or nn_module):
w = attempt_download_asset(w)

# Load model
if nn_module: # in-memory PyTorch model
# In-memory PyTorch model
if nn_module:
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
if hasattr(model, "kpt_shape"):
@@ -146,7 +146,9 @@ def __init__(
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
pt = True
elif pt: # PyTorch

# PyTorch
elif pt:
from ultralytics.nn.tasks import attempt_load_weights

model = attempt_load_weights(
@@ -158,30 +160,38 @@ def __init__(
names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript

# TorchScript
elif jit:
LOGGER.info(f"Loading {w} for TorchScript inference...")
extra_files = {"config.txt": ""} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files["config.txt"]: # load metadata dict
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
elif dnn: # ONNX OpenCV DNN

# ONNX OpenCV DNN
elif dnn:
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime

# ONNX Runtime
elif onnx:
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
import onnxruntime

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map # metadata
elif xml: # OpenVINO
metadata = session.get_modelmeta().custom_metadata_map

# OpenVINO
elif xml:
LOGGER.info(f"Loading {w} for OpenVINO inference...")
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
import openvino as ov # noqa
check_requirements("openvino>=2023.3")
import openvino as ov

core = ov.Core()
w = Path(w)
@@ -193,9 +203,18 @@ def __init__(
batch_dim = ov.get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device

inference_mode = "LATENCY" # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
ov_compiled_model = core.compile_model(
ov_model,
device_name="AUTO", # AUTO selects best available device, do not modify
config={"PERFORMANCE_HINT": inference_mode},
)
input_name = ov_compiled_model.input().get_any_name()
metadata = w.parent / "metadata.yaml"
elif engine: # TensorRT

# TensorRT
elif engine:
LOGGER.info(f"Loading {w} for TensorRT inference...")
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
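The hunk above is the core of this PR: the OpenVINO model is now compiled with an explicit PERFORMANCE_HINT, defaulting to LATENCY. A standalone sketch of the same OpenVINO call outside AutoBackend, with the model path and input shape as assumptions:

import numpy as np
import openvino as ov

core = ov.Core()
ov_model = core.read_model("yolov8n_openvino_model/yolov8n.xml")  # illustrative path

# LATENCY favors the fastest single result at batch-size 1; THROUGHPUT and
# CUMULATIVE_THROUGHPUT trade first-result latency for higher aggregate FPS.
compiled = core.compile_model(ov_model, device_name="AUTO", config={"PERFORMANCE_HINT": "LATENCY"})

im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy BCHW input
y = list(compiled(im).values())  # synchronous call, as in the LATENCY branch of forward()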
@@ -234,20 +253,26 @@ def __init__(
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
elif coreml: # CoreML

# CoreML
elif coreml:
LOGGER.info(f"Loading {w} for CoreML inference...")
import coremltools as ct

model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
elif saved_model: # TF SavedModel

# TF SavedModel
elif saved_model:
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
import tensorflow as tf

keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
metadata = Path(w) / "metadata.yaml"
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt

# TF GraphDef
elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
import tensorflow as tf

@@ -263,6 +288,8 @@ def wrap_frozen_graph(gd, inputs, outputs):
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))

# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
@@ -287,9 +314,13 @@ def wrap_frozen_graph(gd, inputs, outputs):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
elif tfjs: # TF.js

# TF.js
elif tfjs:
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
elif paddle: # PaddlePaddle

# PaddlePaddle
elif paddle:
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
import paddle.inference as pdi # noqa
@@ -304,7 +335,9 @@ def wrap_frozen_graph(gd, inputs, outputs):
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / "metadata.yaml"
elif ncnn: # NCNN

# NCNN
elif ncnn:
LOGGER.info(f"Loading {w} for NCNN inference...")
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN
import ncnn as pyncnn
@@ -317,18 +350,21 @@ def wrap_frozen_graph(gd, inputs, outputs):
net.load_param(str(w))
net.load_model(str(w.with_suffix(".bin")))
metadata = w.parent / "metadata.yaml"
elif triton: # NVIDIA Triton Inference Server

# NVIDIA Triton Inference Server
elif triton:
check_requirements("tritonclient[all]")
from ultralytics.utils.triton import TritonRemoteModel

model = TritonRemoteModel(w)

# Any other format (unsupported)
else:
from ultralytics.engine.exporter import export_formats

raise TypeError(
f"model='{w}' is not a supported model format. "
"See https://docs.ultralytics.com/modes/predict for help."
f"\n\n{export_formats()}"
f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
)

# Load external metadata YAML
@@ -380,21 +416,51 @@ def forward(self, im, augment=False, visualize=False, embed=None):
if self.nhwc:
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)

if self.pt or self.nn_module: # PyTorch
# PyTorch
if self.pt or self.nn_module:
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
elif self.jit: # TorchScript

# TorchScript
elif self.jit:
y = self.model(im)
elif self.dnn: # ONNX OpenCV DNN

# ONNX OpenCV DNN
elif self.dnn:
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
elif self.onnx: # ONNX Runtime

# ONNX Runtime
elif self.onnx:
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
elif self.xml: # OpenVINO

# OpenVINO
elif self.xml:
im = im.cpu().numpy() # FP32
y = list(self.ov_compiled_model(im).values())
elif self.engine: # TensorRT

if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes
n = im.shape[0] # number of images in batch
results = [None] * n # preallocate list with None to match the number of images

def callback(request, userdata):
"""Places result in preallocated list using userdata index."""
results[userdata] = request.results

# Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
async_queue.set_callback(callback)
for i in range(n):
# Start async inference with userdata=i to specify the position in results list
async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
async_queue.wait_all() # wait for all inference requests to complete
y = [list(r.values()) for r in results][0]

else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
y = list(self.ov_compiled_model(im).values())

# TensorRT
elif self.engine:
if self.dynamic and im.shape != self.bindings["images"].shape:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
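The THROUGHPUT branch above relies on OpenVINO's AsyncInferQueue to run one infer request per image in parallel. A minimal standalone version of that pattern, with the model path and batch contents as placeholders:

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model(
    core.read_model("yolov8n_openvino_model/yolov8n.xml"),  # illustrative path
    device_name="AUTO",
    config={"PERFORMANCE_HINT": "THROUGHPUT"},
)
input_name = compiled.input().get_any_name()

batch = np.zeros((4, 3, 640, 640), dtype=np.float32)  # dummy BCHW batch of 4
results = [None] * batch.shape[0]  # preallocate so callbacks can write by index

def callback(request, userdata):
    results[userdata] = request.results  # userdata carries the image index

queue = ov.runtime.AsyncInferQueue(compiled)  # pool of parallel infer requests
queue.set_callback(callback)
for i in range(batch.shape[0]):
    queue.start_async(inputs={input_name: batch[i : i + 1]}, userdata=i)  # one request per image, kept BCHW
queue.wait_all()  # block until every request has completed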
@@ -407,7 +473,9 @@ def forward(self, im, augment=False, visualize=False, embed=None):
self.binding_addrs["images"] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
elif self.coreml: # CoreML

# CoreML
elif self.coreml:
im = im[0].cpu().numpy()
im_pil = Image.fromarray((im * 255).astype("uint8"))
# im = im.resize((192, 320), Image.BILINEAR)
@@ -426,12 +494,16 @@ def forward(self, im, augment=False, visualize=False, embed=None):
y = list(y.values())
elif len(y) == 2: # segmentation model
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
elif self.paddle: # PaddlePaddle

# PaddlePaddle
elif self.paddle:
im = im.cpu().numpy().astype(np.float32)
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
elif self.ncnn: # NCNN

# NCNN
elif self.ncnn:
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
ex = self.net.create_extractor()
input_names, output_names = self.net.input_names(), self.net.output_names()
@@ -441,10 +513,14 @@ def forward(self, im, augment=False, visualize=False, embed=None):
mat_out = self.pyncnn.Mat()
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif self.triton: # NVIDIA Triton Inference Server

# NVIDIA Triton Inference Server
elif self.triton:
im = im.cpu().numpy() # torch to numpy
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)

# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
else:
im = im.cpu().numpy()
if self.saved_model: # SavedModel
y = self.model(im, training=False) if self.keras else self.model(im)