Merge pull request #134 from pipeless-ai/yolo-world

Support several inference outputs for ONNX Runtime + YOLO World example

miguelaeh committed Feb 23, 2024
2 parents 99676e4 + f15d183 commit 1d288f8
Showing 11 changed files with 267 additions and 25 deletions.
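For Python hooks this is a breaking change: with the ONNX runtime, `frame_data["inference_output"]` is now a mapping from the model's output names to arrays instead of a single array, which is why the existing examples below switch to `.get(...)`. A minimal sketch of a post-process hook under the new contract (the output names shown are illustrative; they depend on the model):

```python
def hook(frame_data, _):
    # The ONNX runtime now returns a dict keyed by the model's output names
    model_output = frame_data["inference_output"]
    if len(model_output) > 0:
        # Names depend on the model: e.g. "output0" for the YOLOv8 example,
        # "boxes"/"labels"/"scores" for the YOLO World model added below
        scores = model_output.get("scores", [])
        # ... post-process the named outputs here ...
```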
2 changes: 1 addition & 1 deletion examples/onnx-candy/post-process.py
@@ -2,7 +2,7 @@
 import cv2

 def hook(frame_data, _):
-    inference_results = frame_data["inference_output"]
+    inference_results = frame_data["inference_output"].get("output1", [])
     candy_image = inference_results[0] # Remove batch axis
     candy_image = np.clip(candy_image, 0, 255)
     candy_image = candy_image.transpose(1,2,0).astype("uint8")
2 changes: 1 addition & 1 deletion examples/onnx-yolo/post-process.py
@@ -6,7 +6,7 @@ def hook(frame_data, _):
     model_output = frame_data['inference_output']
     if len(model_output) > 0:
         yolo_input_shape = (640, 640, 3) # h,w,c
-        boxes, scores, class_ids = postprocess_yolo(frame.shape, yolo_input_shape, model_output)
+        boxes, scores, class_ids = postprocess_yolo(frame.shape, yolo_input_shape, model_output.get("output0", []))
         class_labels = [yolo_classes[id] for id in class_ids]
         for i in range(len(boxes)):
             draw_bbox(frame, boxes[i], class_labels[i], scores[i], color_palette[class_ids[i]])
100 changes: 100 additions & 0 deletions examples/yolo-world/post-process.py
@@ -0,0 +1,100 @@
import cv2
import numpy as np

def hook(frame_data, _):
    frame = frame_data['original']
    model_output = frame_data['inference_output']
    if len(model_output) > 0:
        yolo_input_shape = (640, 640, 3) # h,w,c
        boxes, scores, class_ids = postprocess_yolo_world(frame.shape, yolo_input_shape, model_output)
        class_labels = [yolo_classes[int(id)] for id in class_ids]
        for i in range(len(boxes)):
            draw_bbox(frame, boxes[i], class_labels[i], scores[i], color_palette[int(class_ids[i])])

    frame_data['modified'] = frame

#################################################
# Util functions to make the hook more readable #
#################################################
yolo_classes = ['hard hat', 'gloves', 'protective boot', 'reflective vest', 'person']
color_palette = np.random.uniform(0, 255, size=(len(yolo_classes), 3))

def draw_bbox(image, box, label='', score=None, color=(255, 0, 255), txt_color=(255, 255, 255)):
    lw = max(round(sum(image.shape) / 2 * 0.003), 2)
    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
    if label:
        tf = max(lw - 1, 1) # font thickness
        w, h = cv2.getTextSize(str(label), 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
        outside = p1[1] - h >= 3
        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
        if score is not None:
            cv2.putText(image, f'{label} - {score}', (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
        else:
            cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0, lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)

def postprocess_yolo_world(original_frame_shape, resized_img_shape, output):
    original_height, original_width, _ = original_frame_shape
    resized_height, resized_width, _ = resized_img_shape

    boxes = np.array(output['boxes'][0])
    classes = np.array(output['labels'][0])
    scores = np.array(output['scores'][0])

    # Filter out negative indexes
    neg_indexes_classes = np.where(classes < 0)[0]
    neg_indexes_scores = np.where(scores < 0)[0]
    neg_indexes = np.concatenate((neg_indexes_classes, neg_indexes_scores))

    mask = np.ones(classes.shape, dtype=bool)
    mask[neg_indexes] = False

    boxes = boxes[mask]
    classes = classes[mask]
    scores = scores[mask]

    # Arrays to accumulate the results
    result_boxes = []
    result_classes = []
    result_scores = []

    # Calculate the scaling factor for the bounding box coordinates
    if original_height > original_width:
        scale_factor = original_height / resized_height
    else:
        scale_factor = original_width / resized_width

    # Resize the output boxes
    for i, score in enumerate(scores):
        if score < 0.05: # apply confidence threshold
            continue
        if score >= 1: # discard bad predictions that return a score of 1.0
            continue

        x1, y1, x2, y2 = boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]

        # Calculate the scaled coordinates of the bounding box;
        # the original image was padded to be square
        if original_height > original_width:
            # We added padding on the width
            pad = (resized_width - original_width / scale_factor) // 2
            x1 = int((x1 - pad) * scale_factor)
            y1 = int(y1 * scale_factor)
            x2 = int((x2 - pad) * scale_factor)
            y2 = int(y2 * scale_factor)
        else:
            # We added padding on the height
            pad = (resized_height - original_height / scale_factor) // 2
            x1 = int(x1 * scale_factor)
            y1 = int((y1 - pad) * scale_factor)
            x2 = int(x2 * scale_factor)
            y2 = int((y2 - pad) * scale_factor)

        result_classes.append(classes[i])
        result_scores.append(score)
        result_boxes.append([x1, y1, x2, y2])

    return result_boxes, result_scores, result_classes
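The coordinate mapping above undoes the letterbox padding applied in pre-process.py. As an illustrative sanity check of the arithmetic, assuming a hypothetical 1920x1080 frame and the 640x640 model input (not part of the commit):

```python
# Illustrative check of the scale/pad arithmetic for a 1920x1080 frame
original_height, original_width = 1080, 1920
resized_height, resized_width = 640, 640

# Width is the long side, so it fixes the scale and padding goes on the height
scale_factor = original_width / resized_width                 # 3.0
pad = (resized_height - original_height / scale_factor) // 2  # (640 - 360) // 2 = 140.0

# A y coordinate detected at 300 in the padded 640x640 input maps back to:
y_original = int((300 - pad) * scale_factor)                  # (300 - 140) * 3 = 480
```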
50 changes: 50 additions & 0 deletions examples/yolo-world/pre-process.py
@@ -0,0 +1,50 @@
import cv2
import numpy as np

def is_cuda_available():
    return cv2.cuda.getCudaEnabledDeviceCount() > 0

def resize_and_pad(frame, target_dim, pad_top, pad_bottom, pad_left, pad_right):
    """Resize and pad an image. Uses CUDA when available."""
    target_height, target_width = target_dim
    if is_cuda_available():
        # FIXME: due to the GPU memory allocation here, this can be even slower than running
        # on the CPU. We should provide the frame to the hook directly from GPU memory.
        frame_gpu = cv2.cuda_GpuMat(frame)
        resized_frame_gpu = cv2.cuda.resize(frame_gpu, (target_width, target_height), interpolation=cv2.INTER_CUBIC)
        padded_frame_gpu = cv2.cuda.copyMakeBorder(resized_frame_gpu, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        result = padded_frame_gpu.download()
        return result
    else:
        resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC)
        padded_frame = cv2.copyMakeBorder(resized_frame, pad_top, pad_bottom, pad_left, pad_right,
                                          borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0))
        return padded_frame

def resize_with_padding(frame, target_dim):
    target_height, target_width, _ = target_dim
    frame_height, frame_width, _ = frame.shape

    width_ratio = target_width / frame_width
    height_ratio = target_height / frame_height
    # Choose the minimum scaling factor to maintain the aspect ratio
    scale_factor = min(width_ratio, height_ratio)
    # Calculate the new dimensions after resizing
    new_width = int(frame_width * scale_factor)
    new_height = int(frame_height * scale_factor)
    # Calculate the padding dimensions
    pad_width = (target_width - new_width) // 2
    pad_height = (target_height - new_height) // 2

    padded_image = resize_and_pad(frame, (new_height, new_width), pad_height, pad_height, pad_width, pad_width)
    return padded_image

def hook(frame_data, _):
    frame = frame_data["original"].view()
    yolo_input_shape = (640, 640, 3) # h,w,c
    frame = resize_with_padding(frame, yolo_input_shape)
    frame = np.array(frame) / 255.0 # Normalize pixel values
    frame = np.transpose(frame, axes=(2,0,1)) # Convert to c,h,w
    inference_inputs = frame.astype("float32")
    frame_data['inference_input'] = inference_inputs
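A quick smoke test for this hook, illustrative only and not part of the commit: with a 1920x1080 frame, the 1080 side scales to 360 and is padded with 140 rows on top and bottom, so the tensor handed to inference comes out as c,h,w:

```python
import numpy as np

# Hypothetical smoke test for the pre-process hook above
frame_data = {"original": np.zeros((1080, 1920, 3), dtype=np.uint8)}
hook(frame_data, None)
assert frame_data["inference_input"].shape == (3, 640, 640)
assert frame_data["inference_input"].dtype == np.float32
```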
7 changes: 7 additions & 0 deletions examples/yolo-world/process.json
@@ -0,0 +1,7 @@
{
    "runtime": "onnx",
    "model_uri": "https://pipeless-public.s3.eu-west-3.amazonaws.com/yolow-l-ppe.onnx",
    "inference_params": {
        "execution_provider": "cpu"
    }
}
2 changes: 1 addition & 1 deletion pipeless/Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pipeless/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "pipeless-ai"
-version = "1.10.0"
+version = "1.11.0"
 edition = "2021"
 authors = ["Miguel A. Cabrera Minagorri"]
 description = "An open-source computer vision framework to build and deploy applications in minutes"
20 changes: 13 additions & 7 deletions pipeless/src/data.rs
@@ -14,6 +14,11 @@ pub enum UserData {
     Dictionary(Vec<(String, UserData)>),
 }

+pub enum InferenceOutput {
+    Default(ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>),
+    OnnxInferenceOutput(crate::stages::inference::onnx::OnnxInferenceOutput)
+}
+
 pub struct RgbFrame {
     uuid: uuid::Uuid,
     original: ndarray::Array3<u8>,
@@ -26,7 +31,8 @@ pub struct RgbFrame {
     fps: u8,
     input_ts: f64, // epoch in seconds
     inference_input: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
-    inference_output: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
+    // We could turn the output into an array view, since the user does not need to modify it and the inference runtime returns a view, which would avoid a copy
+    inference_output: InferenceOutput,
     pipeline_id: uuid::Uuid,
     user_data: UserData,
     frame_number: u64,
@@ -47,7 +53,7 @@ impl RgbFrame {
             pts, dts, duration, fps,
             input_ts,
             inference_input: ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0])),
-            inference_output: ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0])),
+            inference_output: InferenceOutput::Default(ndarray::ArrayBase::zeros(ndarray::IxDyn(&[0]))),
             pipeline_id,
             user_data: UserData::Empty,
             frame_number,
@@ -62,7 +68,7 @@ impl RgbFrame {
         pts: u64, dts: u64, duration: u64,
         fps: u8, input_ts: f64,
         inference_input: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
-        inference_output: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>,
+        inference_output: InferenceOutput,
         pipeline_id: &str,
         user_data: UserData, frame_number: u64,
     ) -> Self {
@@ -122,13 +128,13 @@ impl RgbFrame {
     pub fn get_inference_input(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
         &self.inference_input
     }
-    pub fn get_inference_output(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
+    pub fn get_inference_output(&self) -> &InferenceOutput {
         &self.inference_output
     }
     pub fn set_inference_input(&mut self, input_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
         self.inference_input = input_data;
     }
-    pub fn set_inference_output(&mut self, output_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
+    pub fn set_inference_output(&mut self, output_data: InferenceOutput) {
         self.inference_output = output_data;
     }
     pub fn get_pipeline_id(&self) -> &uuid::Uuid {
@@ -180,7 +186,7 @@ impl Frame {
             Frame::RgbFrame(frame) => frame.get_inference_input()
         }
     }
-    pub fn get_inference_output(&self) -> &ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>> {
+    pub fn get_inference_output(&self) -> &InferenceOutput {
         match self {
             Frame::RgbFrame(frame) => frame.get_inference_output()
         }
@@ -190,7 +196,7 @@
             Frame::RgbFrame(frame) => { frame.set_inference_input(input_data); },
         }
     }
-    pub fn set_inference_output(&mut self, output_data: ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>) {
+    pub fn set_inference_output(&mut self, output_data: InferenceOutput) {
         match self {
             Frame::RgbFrame(frame) => { frame.set_inference_output(output_data); },
         }
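On the Python side, a hook may therefore see either the legacy single-array output (the Default variant) or the new dict of named outputs. A defensive accessor, assuming the bindings surface the Default variant as a plain array and using `output0` as a placeholder name, could look like:

```python
import numpy as np

def get_named_output(frame_data, name="output0"):
    # Hypothetical helper: tolerate both output shapes
    out = frame_data["inference_output"]
    if isinstance(out, dict):          # ONNX runtime: dict of named outputs
        return np.asarray(out.get(name, []))
    return np.asarray(out)             # legacy: a single array
```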
2 changes: 2 additions & 0 deletions pipeless/src/events.rs
@@ -203,6 +203,8 @@ pub fn publish_new_frame_change_event_sync(
 ) {
     let new_frame_event = Event::new_frame_change(frame);
     // By using try_send frames are discarded when the channel is full
+    // However, note this does not produce fluid output video. For that, instead of discarding
+    // the frame, we would need to send it to the output without processing it.
     if let Err(err) = bus_sender.try_send(new_frame_event) {
         debug!("Discarding frame: {}", err);
     }
45 changes: 41 additions & 4 deletions pipeless/src/stages/inference/onnx.rs
@@ -1,8 +1,12 @@
+use std::collections::HashMap;
+
 use log::{error, warn};
 use ort;

 use crate as pipeless;

+pub type OnnxInferenceOutput = HashMap<String, ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>>;
+
 pub struct OnnxSessionParams {
     stage_name: String, // Name of the stage this session belongs to
     execution_provider: String, // The user has to provide the execution provider
@@ -163,10 +167,43 @@ impl super::session::SessionTrait for OnnxSession {
         match self.session.run_with_binding(&io_bindings) {
             Ok(()) => {
                 let outputs = io_bindings.outputs().unwrap();
-                // TODO: iterate over the outputs hashmap to return all the model outputs not just the first
-                let output = outputs[&self.session.outputs[0].name].try_extract().unwrap();
-                let output_ndarray = output.view().to_owned();
-                frame.set_inference_output(output_ndarray);
+                let mut frame_inference_output = OnnxInferenceOutput::new();
+                for (output_name, output_value) in outputs {
+                    // FIXME: this extraction code is inelegant. try_extract can return several different numeric types
+                    // depending on the model, and there is no common numeric wrapper, so we have to check type by type.
+                    match output_value.try_extract() {
+                        Ok(output) => {
+                            // FIXME: we could use an array view for the inference output instead of an owned array to avoid this copy.
+                            let output_ndarray = output.view().to_owned();
+                            frame_inference_output.insert(output_name, output_ndarray);
+                        },
+                        Err(_err) => {
+                            // Try to extract as i64, since some models do not return floats
+                            match output_value.try_extract() {
+                                Ok(output) => {
+                                    // FIXME: this copies the array twice, first to_owned and then mapv
+                                    let output_ndarray: ndarray::ArrayBase<ndarray::OwnedRepr<i64>, _> = output.view().to_owned();
+                                    let float_output = output_ndarray.mapv(|v| v as f32);
+                                    frame_inference_output.insert(output_name, float_output);
+                                }
+                                Err(_err) => {
+                                    // Finally, try to extract as i32
+                                    match output_value.try_extract() {
+                                        Ok(output) => {
+                                            // FIXME: this copies the array twice, first to_owned and then mapv
+                                            let output_ndarray: ndarray::ArrayBase<ndarray::OwnedRepr<i32>, _> = output.view().to_owned();
+                                            let float_output = output_ndarray.mapv(|v| v as f32);
+                                            frame_inference_output.insert(output_name, float_output);
+                                        }
+                                        Err(err) => warn!("Error extracting inference results: {}", err.to_string()),
+                                    }
+                                },
+                            }
+                        }
+                    }
+                }
+                frame.set_inference_output(pipeless::data::InferenceOutput::OnnxInferenceOutput(frame_inference_output));
             },
             Err(err) => error!("There was an error running inference: {}", err)
         }
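The nested try_extract chain exists because a model can emit float32 boxes and scores alongside int64 or int32 class labels, while OnnxInferenceOutput is typed as f32 arrays throughout. A rough numpy analogue of that normalization, for illustration only:

```python
import numpy as np

def normalize_output(arr):
    # Mirror the Rust fallback chain: integer outputs (e.g. YOLO World's
    # "labels") are cast to float32 so every output shares one dtype
    if arr.dtype in (np.int64, np.int32):
        return arr.astype(np.float32)
    return arr
```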