ultralytics 8.1.25 OpenVINO LATENCY and THROUGHPUT modes (#8058)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Adrian Boguszewski <adekboguszewski@gmail.com>
3 people committed Mar 6, 2024
1 parent 6da7c9f commit 9094394
Showing 2 changed files with 112 additions and 36 deletions.
6 changes: 3 additions & 3 deletions ultralytics/engine/exporter.py
@@ -411,8 +411,8 @@ def export_onnx(self, prefix=colorstr("ONNX:")):
@try_export
def export_openvino(self, prefix=colorstr("OpenVINO:")):
"""YOLOv8 OpenVINO export."""
-check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino-dev/
-import openvino as ov  # noqa
+check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino/
+import openvino as ov

LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
@@ -433,7 +433,7 @@ def serialize(ov_model, file):
if self.model.task != "classify":
ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

-ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
+ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half)
yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml

if self.args.int8:
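For context, the export path touched above can be exercised end to end from the Python API; a minimal sketch (the weights file and image size below are illustrative, not part of this commit):

    from ultralytics import YOLO

    # Export a PyTorch checkpoint to OpenVINO IR (model.xml + model.bin).
    # half=True maps to the compress_to_fp16 argument passed to save_model above.
    model = YOLO("yolov8n.pt")  # example checkpoint
    model.export(format="openvino", half=True, imgsz=640)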
142 changes: 109 additions & 33 deletions ultralytics/nn/autobackend.py
@@ -135,8 +135,8 @@ def __init__(
if not (pt or triton or nn_module):
w = attempt_download_asset(w)

-# Load model
-if nn_module:  # in-memory PyTorch model
+# In-memory PyTorch model
+if nn_module:
model = weights.to(device)
model = model.fuse(verbose=verbose) if fuse else model
if hasattr(model, "kpt_shape"):
@@ -146,7 +146,9 @@ def __init__(
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
pt = True
-elif pt:  # PyTorch
+
+# PyTorch
+elif pt:
from ultralytics.nn.tasks import attempt_load_weights

model = attempt_load_weights(
@@ -158,30 +158,38 @@ def __init__(
names = model.module.names if hasattr(model, "module") else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
-elif jit:  # TorchScript
+
+# TorchScript
+elif jit:
LOGGER.info(f"Loading {w} for TorchScript inference...")
extra_files = {"config.txt": ""} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files["config.txt"]: # load metadata dict
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
-elif dnn:  # ONNX OpenCV DNN
+
+# ONNX OpenCV DNN
+elif dnn:
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w)
-elif onnx:  # ONNX Runtime
+
+# ONNX Runtime
+elif onnx:
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
import onnxruntime

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
-metadata = session.get_modelmeta().custom_metadata_map  # metadata
-elif xml:  # OpenVINO
+metadata = session.get_modelmeta().custom_metadata_map
+
+# OpenVINO
+elif xml:
LOGGER.info(f"Loading {w} for OpenVINO inference...")
-check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino-dev/
-import openvino as ov  # noqa
+check_requirements("openvino>=2023.3")
+import openvino as ov

core = ov.Core()
w = Path(w)
@@ -193,9 +203,18 @@ def __init__(
batch_dim = ov.get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
-ov_compiled_model = core.compile_model(ov_model, device_name="AUTO")  # AUTO selects best available device
+
+inference_mode = "LATENCY"  # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+ov_compiled_model = core.compile_model(
+    ov_model,
+    device_name="AUTO",  # AUTO selects best available device, do not modify
+    config={"PERFORMANCE_HINT": inference_mode},
+)
input_name = ov_compiled_model.input().get_any_name()
metadata = w.parent / "metadata.yaml"
-elif engine:  # TensorRT
+
+# TensorRT
+elif engine:
LOGGER.info(f"Loading {w} for TensorRT inference...")
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
@@ -234,20 +253,26 @@ def __init__(
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
-elif coreml:  # CoreML
+
+# CoreML
+elif coreml:
LOGGER.info(f"Loading {w} for CoreML inference...")
import coremltools as ct

model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
-elif saved_model:  # TF SavedModel
+
+# TF SavedModel
+elif saved_model:
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
import tensorflow as tf

keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
metadata = Path(w) / "metadata.yaml"
-elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+
+# TF GraphDef
+elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
import tensorflow as tf

@@ -263,6 +288,8 @@ def wrap_frozen_graph(gd, inputs, outputs):
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+
+# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
@@ -287,9 +314,13 @@ def wrap_frozen_graph(gd, inputs, outputs):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
-elif tfjs:  # TF.js
+
+# TF.js
+elif tfjs:
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
-elif paddle:  # PaddlePaddle
+
+# PaddlePaddle
+elif paddle:
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
import paddle.inference as pdi # noqa
@@ -304,7 +335,9 @@ def wrap_frozen_graph(gd, inputs, outputs):
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / "metadata.yaml"
-elif ncnn:  # NCNN
+
+# NCNN
+elif ncnn:
LOGGER.info(f"Loading {w} for NCNN inference...")
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN
import ncnn as pyncnn
@@ -317,18 +350,21 @@ def wrap_frozen_graph(gd, inputs, outputs):
net.load_param(str(w))
net.load_model(str(w.with_suffix(".bin")))
metadata = w.parent / "metadata.yaml"
-elif triton:  # NVIDIA Triton Inference Server
+
+# NVIDIA Triton Inference Server
+elif triton:
check_requirements("tritonclient[all]")
from ultralytics.utils.triton import TritonRemoteModel

model = TritonRemoteModel(w)
+
+# Any other format (unsupported)
else:
from ultralytics.engine.exporter import export_formats

raise TypeError(
f"model='{w}' is not a supported model format. "
-"See https://docs.ultralytics.com/modes/predict for help."
-f"\n\n{export_formats()}"
+f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
)

# Load external metadata YAML
@@ -380,21 +416,51 @@ def forward(self, im, augment=False, visualize=False, embed=None):
if self.nhwc:
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)

-if self.pt or self.nn_module:  # PyTorch
+# PyTorch
+if self.pt or self.nn_module:
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
-elif self.jit:  # TorchScript
+
+# TorchScript
+elif self.jit:
y = self.model(im)
-elif self.dnn:  # ONNX OpenCV DNN
+
+# ONNX OpenCV DNN
+elif self.dnn:
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
-elif self.onnx:  # ONNX Runtime
+
+# ONNX Runtime
+elif self.onnx:
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
-elif self.xml:  # OpenVINO
+
+# OpenVINO
+elif self.xml:
im = im.cpu().numpy() # FP32
-y = list(self.ov_compiled_model(im).values())
-elif self.engine:  # TensorRT
+
+if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
+    n = im.shape[0]  # number of images in batch
+    results = [None] * n  # preallocate list with None to match the number of images
+
+    def callback(request, userdata):
+        """Places result in preallocated list using userdata index."""
+        results[userdata] = request.results
+
+    # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
+    async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
+    async_queue.set_callback(callback)
+    for i in range(n):
+        # Start async inference with userdata=i to specify the position in results list
+        async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
+    async_queue.wait_all()  # wait for all inference requests to complete
+    y = [list(r.values()) for r in results][0]
+
+else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
+    y = list(self.ov_compiled_model(im).values())
+
+# TensorRT
+elif self.engine:
if self.dynamic and im.shape != self.bindings["images"].shape:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
@@ -407,7 +473,9 @@ def forward(self, im, augment=False, visualize=False, embed=None):
self.binding_addrs["images"] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
-elif self.coreml:  # CoreML
+
+# CoreML
+elif self.coreml:
im = im[0].cpu().numpy()
im_pil = Image.fromarray((im * 255).astype("uint8"))
# im = im.resize((192, 320), Image.BILINEAR)
@@ -426,12 +494,16 @@ def forward(self, im, augment=False, visualize=False, embed=None):
y = list(y.values())
elif len(y) == 2: # segmentation model
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
-elif self.paddle:  # PaddlePaddle
+
+# PaddlePaddle
+elif self.paddle:
im = im.cpu().numpy().astype(np.float32)
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
-elif self.ncnn:  # NCNN
+
+# NCNN
+elif self.ncnn:
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
ex = self.net.create_extractor()
input_names, output_names = self.net.input_names(), self.net.output_names()
@@ -441,10 +513,14 @@ def forward(self, im, augment=False, visualize=False, embed=None):
mat_out = self.pyncnn.Mat()
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
-elif self.triton:  # NVIDIA Triton Inference Server
+
+# NVIDIA Triton Inference Server
+elif self.triton:
im = im.cpu().numpy() # torch to numpy
y = self.model(im)
-else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+
+# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+else:
im = im.cpu().numpy()
if self.saved_model: # SavedModel
y = self.model(im, training=False) if self.keras else self.model(im)
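For readers who want to try the new performance hints outside of AutoBackend, the following is a minimal standalone sketch of the same pattern against an exported IR, using openvino>=2023.3 (the IR path, batch size, and input shapes are illustrative; "LATENCY" is the default hint chosen by this commit, and THROUGHPUT is shown here only to exercise the async queue):

    import numpy as np
    import openvino as ov

    core = ov.Core()
    ov_model = core.read_model("yolov8n_openvino_model/yolov8n.xml")  # example IR from the exporter

    # LATENCY favors the fastest single result at batch-size 1;
    # THROUGHPUT / CUMULATIVE_THROUGHPUT favor aggregate FPS on larger batches.
    compiled = core.compile_model(ov_model, device_name="AUTO", config={"PERFORMANCE_HINT": "THROUGHPUT"})
    input_name = compiled.input().get_any_name()

    batch = np.random.rand(4, 3, 640, 640).astype(np.float32)  # example BCHW batch
    results = [None] * batch.shape[0]

    def callback(request, userdata):
        """Store each finished request's outputs at its original batch index."""
        results[userdata] = request.results

    async_queue = ov.runtime.AsyncInferQueue(compiled)
    async_queue.set_callback(callback)
    for i in range(batch.shape[0]):
        async_queue.start_async(inputs={input_name: batch[i : i + 1]}, userdata=i)
    async_queue.wait_all()

    outputs = [list(r.values()) for r in results]  # one list of output arrays per image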
