Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

yolov8 pose #1502

Merged
merged 24 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions yolov8/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ else()
# tensorrt
include_directories(/home/lindsay/TensorRT-8.4.1.5/include)
link_directories(/home/lindsay/TensorRT-8.4.1.5/lib)
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)


endif()
Expand All @@ -51,5 +51,9 @@ target_link_libraries(yolov8_det ${OpenCV_LIBS})
add_executable(yolov8_seg ${PROJECT_SOURCE_DIR}/yolov8_seg.cpp ${SRCS})
target_link_libraries(yolov8_seg nvinfer cudart myplugins ${OpenCV_LIBS})


add_executable(yolov8_pose ${PROJECT_SOURCE_DIR}/yolov8_pose.cpp ${SRCS})
target_link_libraries(yolov8_pose nvinfer cudart myplugins ${OpenCV_LIBS})

add_executable(yolov8_cls ${PROJECT_SOURCE_DIR}/yolov8_cls.cpp ${SRCS})
target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS})
target_link_libraries(yolov8_cls nvinfer cudart myplugins ${OpenCV_LIBS})
19 changes: 19 additions & 0 deletions yolov8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ sudo ./yolov8_cls -s yolov8n-cls.wts yolov8-cls.engine n
sudo ./yolov8_cls -d yolov8n-cls.engine ../samples
```


### Pose Estimation
```
cd {tensorrtx}/yolov8/
// update "kNumClass = 1" in config.h
mkdir build
cd build
cp {ultralytics}/ultralytics/yolov8-pose.wts {tensorrtx}/yolov8/build
cmake ..
make
sudo ./yolov8_pose -s [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6] // serialize model to plan file
sudo ./yolov8_pose -d [.engine] [image folder] [c/g] // deserialize and run inference, the images in [image folder] will be processed.
// For example yolov8-pose
sudo ./yolov8_pose -s yolov8n-pose.wts yolov8n-pose.engine n
sudo ./yolov8_pose -d yolov8n-pose.engine ../images c //cpu postprocess
sudo ./yolov8_pose -d yolov8n-pose.engine ../images g //gpu postprocess
```


4. optional, load and run the tensorrt model in python

```
Expand Down
2 changes: 1 addition & 1 deletion yolov8/include/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation);
int px_arry_num, bool is_segmentation, bool is_pose);
2 changes: 2 additions & 0 deletions yolov8/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static int kNumClass = 80;
const static int kNumberOfPoints = 17; // number of keypoints total
const static int kBatchSize = 1;
const static int kGpuId = 0;
const static int kInputH = 640;
const static int kInputW = 640;
const static float kNmsThresh = 0.45f;
const static float kConfThresh = 0.5f;
const static float kConfThreshKeypoints = 0.5f; // keypoints confidence
const static int kMaxInputImageSize = 3000 * 3000;
const static int kMaxNumOutputBbox = 1000;

Expand Down
4 changes: 4 additions & 0 deletions yolov8/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ nvinfer1::IHostMemory* buildEngineYolov8Cls(nvinfer1::IBuilder* builder, nvinfer
nvinfer1::IHostMemory* buildEngineYolov8Seg(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
int& max_channels);

nvinfer1::IHostMemory* buildEngineYolov8Pose(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config,
nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw,
int& max_channels);
25 changes: 16 additions & 9 deletions yolov8/include/postprocess.h
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
#pragma once

#include "types.h"
#include "NvInfer.h"
#include <opencv2/opencv.hpp>
#include "NvInfer.h"
#include "types.h"

cv::Rect get_rect(cv::Mat& img, float bbox[4]);

void nms(std::vector<Detection>& res, float *output, float conf_thresh, float nms_thresh = 0.5);
void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);

void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
float conf_thresh, float nms_thresh = 0.5);

void batch_nms(std::vector<std::vector<Detection>>& batch_res, float *output, int batch_size, int output_size, float conf_thresh, float nms_thresh = 0.5);
void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

void draw_bbox(std::vector<cv::Mat> &img_batch, std::vector<std::vector<Detection>> &res_batch);
void draw_bbox_keypoints_line(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);

void batch_process(std::vector<std::vector<Detection>> &res_batch, const float* decode_ptr_host, int batch_size, int bbox_element, const std::vector<cv::Mat>& img_batch);
void batch_process(std::vector<std::vector<Detection>>& res_batch, const float* decode_ptr_host, int batch_size,
int bbox_element, const std::vector<cv::Mat>& img_batch);

void process_decode_ptr_host(std::vector<Detection> &res, const float* decode_ptr_host, int bbox_element, cv::Mat& img, int count);
void process_decode_ptr_host(std::vector<Detection>& res, const float* decode_ptr_host, int bbox_element, cv::Mat& img,
int count);

void cuda_decode(float* predict, int num_bboxes, float confidence_threshold,float* parray,int max_objects, cudaStream_t stream);
void cuda_decode(float* predict, int num_bboxes, float confidence_threshold, float* parray, int max_objects,
cudaStream_t stream);

void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream);

void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks, std::unordered_map<int, std::string>& labels_map);
void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
std::unordered_map<int, std::string>& labels_map);
14 changes: 8 additions & 6 deletions yolov8/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
#include "config.h"

struct alignas(float) Detection {
//center_x center_y w h
float bbox[4];
float conf; // bbox_conf * cls_conf
float class_id;
float mask[32];
//center_x center_y w h
float bbox[4];
float conf; // bbox_conf * cls_conf
float class_id;
float mask[32];
float keypoints[51]; // 17*3 keypoints
};

struct AffineMatrix {
float value[6];
};

const int bbox_element = sizeof(AffineMatrix) / sizeof(float)+1; // left, top, right, bottom, confidence, class, keepflag
const int bbox_element =
sizeof(AffineMatrix) / sizeof(float) + 1; // left, top, right, bottom, confidence, class, keepflag
92 changes: 70 additions & 22 deletions yolov8/plugin/yololayer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,26 @@ void read(const char*& buffer, T& val) {
}
} // namespace Tn

__device__ float sigmoid(float x) {
return 1.0f / (1.0f + exp(-x));
}

namespace nvinfer1 {
YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation,
const int* strides, int stridesLength) {
YoloLayerPlugin::YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth,
int netHeight, int maxOut, bool is_segmentation, bool is_pose, const int* strides,
int stridesLength) {

mClassCount = classCount;
mNumberofpoints = numberofpoints;
mConfthreshkeypoints = confthreshkeypoints;
mYoloV8NetWidth = netWidth;
mYoloV8netHeight = netHeight;
mMaxOutObject = maxOut;
mStridesLength = stridesLength;
mStrides = new int[stridesLength];
memcpy(mStrides, strides, stridesLength * sizeof(int));
is_segmentation_ = is_segmentation;
is_pose_ = is_pose;
}

YoloLayerPlugin::~YoloLayerPlugin() {
Expand All @@ -44,6 +53,8 @@ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) {
using namespace Tn;
const char *d = reinterpret_cast<const char*>(data), *a = d;
read(d, mClassCount);
read(d, mNumberofpoints);
read(d, mConfthreshkeypoints);
read(d, mThreadCount);
read(d, mYoloV8NetWidth);
read(d, mYoloV8netHeight);
Expand All @@ -54,6 +65,7 @@ YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) {
read(d, mStrides[i]);
}
read(d, is_segmentation_);
read(d, is_pose_);

assert(d == a + length);
}
Expand All @@ -63,6 +75,8 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
using namespace Tn;
char *d = static_cast<char*>(buffer), *a = d;
write(d, mClassCount);
write(d, mNumberofpoints);
write(d, mConfthreshkeypoints);
write(d, mThreadCount);
write(d, mYoloV8NetWidth);
write(d, mYoloV8netHeight);
Expand All @@ -72,13 +86,15 @@ void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
write(d, mStrides[i]);
}
write(d, is_segmentation_);
write(d, is_pose_);

assert(d == a + getSerializationSize());
}

size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT {
return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mYoloV8netHeight) + sizeof(mYoloV8NetWidth) +
sizeof(mMaxOutObject) + sizeof(mStridesLength) + sizeof(int) * mStridesLength + sizeof(is_segmentation_);
return sizeof(mClassCount) + sizeof(mNumberofpoints) + sizeof(mConfthreshkeypoints) + sizeof(mThreadCount) +
sizeof(mYoloV8netHeight) + sizeof(mYoloV8NetWidth) + sizeof(mMaxOutObject) + sizeof(mStridesLength) +
sizeof(int) * mStridesLength + sizeof(is_segmentation_) + sizeof(is_pose_);
}

int YoloLayerPlugin::initialize() TRT_NOEXCEPT {
Expand Down Expand Up @@ -133,14 +149,14 @@ const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT {
}

void YoloLayerPlugin::destroy() TRT_NOEXCEPT {

delete this;
}

nvinfer1::IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT {

YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV8NetWidth, mYoloV8netHeight, mMaxOutObject,
is_segmentation_, mStrides, mStridesLength);
YoloLayerPlugin* p =
new YoloLayerPlugin(mClassCount, mNumberofpoints, mConfthreshkeypoints, mYoloV8NetWidth, mYoloV8netHeight,
mMaxOutObject, is_segmentation_, is_pose_, mStrides, mStridesLength);
p->setPluginNamespace(mPluginNamespace);
return p;
}
Expand All @@ -157,15 +173,15 @@ __device__ float Logist(float data) {
};

__global__ void CalDetection(const float* input, float* output, int numElements, int maxoutobject, const int grid_h,
int grid_w, const int stride, int classes, int outputElem, bool is_segmentation) {
int grid_w, const int stride, int classes, int nk, float confkeypoints, int outputElem,
bool is_segmentation, bool is_pose) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= numElements)
return;

const int N_kpts = nk;
int total_grid = grid_h * grid_w;
int info_len = 4 + classes;
if (is_segmentation)
info_len += 32;
int info_len = 4 + classes + (is_segmentation ? 32 : 0) + (is_pose ? N_kpts * 3 : 0);
int batchIdx = idx / total_grid;
int elemIdx = idx % total_grid;
const float* curInput = input + batchIdx * total_grid * info_len;
Expand Down Expand Up @@ -200,8 +216,36 @@ __global__ void CalDetection(const float* input, float* output, int numElements,
det->bbox[2] = (col + 0.5f + curInput[elemIdx + 2 * total_grid]) * stride;
det->bbox[3] = (row + 0.5f + curInput[elemIdx + 3 * total_grid]) * stride;

for (int k = 0; is_segmentation && k < 32; k++) {
det->mask[k] = curInput[elemIdx + (k + 4 + classes) * total_grid];
if (is_segmentation) {
for (int k = 0; k < 32; ++k) {
det->mask[k] = curInput[elemIdx + (4 + classes + k) * total_grid];
}
}

if (is_pose) {
for (int kpt = 0; kpt < N_kpts; kpt++) {
int kpt_x_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3) * total_grid;
int kpt_y_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 1) * total_grid;
int kpt_conf_idx = (4 + classes + (is_segmentation ? 32 : 0) + kpt * 3 + 2) * total_grid;

float kpt_confidence = sigmoid(curInput[elemIdx + kpt_conf_idx]);

float kpt_x = (curInput[elemIdx + kpt_x_idx] * 2.0 + col) * stride;
float kpt_y = (curInput[elemIdx + kpt_y_idx] * 2.0 + row) * stride;

bool is_within_bbox =
kpt_x >= det->bbox[0] && kpt_x <= det->bbox[2] && kpt_y >= det->bbox[1] && kpt_y <= det->bbox[3];

if (kpt_confidence < confkeypoints || !is_within_bbox) {
det->keypoints[kpt * 3] = -1;
det->keypoints[kpt * 3 + 1] = -1;
det->keypoints[kpt * 3 + 2] = -1;
} else {
det->keypoints[kpt * 3] = kpt_x;
det->keypoints[kpt * 3 + 1] = kpt_y;
det->keypoints[kpt * 3 + 2] = kpt_confidence;
}
}
}
}

Expand Down Expand Up @@ -230,8 +274,8 @@ void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cuda
mThreadCount = numElem;

CalDetection<<<(numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(
inputs[i], output, numElem, mMaxOutObject, grid_h, grid_w, stride, mClassCount, outputElem,
is_segmentation_);
inputs[i], output, numElem, mMaxOutObject, grid_h, grid_w, stride, mClassCount, mNumberofpoints,
mConfthreshkeypoints, outputElem, is_segmentation_, is_pose_);
}
}

Expand Down Expand Up @@ -260,16 +304,20 @@ IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFi
assert(fc->nbFields == 1);
assert(strcmp(fc->fields[0].name, "combinedInfo") == 0);
const int* combinedInfo = static_cast<const int*>(fc->fields[0].data);
int netinfo_count = 5;
int netinfo_count = 8;
int class_count = combinedInfo[0];
int input_w = combinedInfo[1];
int input_h = combinedInfo[2];
int max_output_object_count = combinedInfo[3];
bool is_segmentation = combinedInfo[4];
int numberofpoints = combinedInfo[1];
float confthreshkeypoints = combinedInfo[2];
int input_w = combinedInfo[3];
int input_h = combinedInfo[4];
int max_output_object_count = combinedInfo[5];
bool is_segmentation = combinedInfo[6];
bool is_pose = combinedInfo[7];
const int* px_arry = combinedInfo + netinfo_count;
int px_arry_length = fc->fields[0].length - netinfo_count;
YoloLayerPlugin* obj = new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, is_segmentation,
px_arry, px_arry_length);
YoloLayerPlugin* obj =
new YoloLayerPlugin(class_count, numberofpoints, confthreshkeypoints, input_w, input_h,
max_output_object_count, is_segmentation, is_pose, px_arry, px_arry_length);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
Expand Down
7 changes: 5 additions & 2 deletions yolov8/plugin/yololayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
namespace nvinfer1 {
class API YoloLayerPlugin : public IPluginV2IOExt {
public:
YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const int* strides,
int stridesLength);
YoloLayerPlugin(int classCount, int numberofpoints, float confthreshkeypoints, int netWidth, int netHeight,
int maxOut, bool is_segmentation, bool is_pose, const int* strides, int stridesLength);

YoloLayerPlugin(const void* data, size_t length);
~YoloLayerPlugin();
Expand Down Expand Up @@ -68,10 +68,13 @@ class API YoloLayerPlugin : public IPluginV2IOExt {
int mThreadCount = 256;
const char* mPluginNamespace;
int mClassCount;
int mNumberofpoints;
float mConfthreshkeypoints;
int mYoloV8NetWidth;
int mYoloV8netHeight;
int mMaxOutObject;
bool is_segmentation_;
bool is_pose_;
int* mStrides;
int mStridesLength;
};
Expand Down
15 changes: 9 additions & 6 deletions yolov8/src/block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,21 @@ nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network,
std::vector<nvinfer1::IConcatenationLayer*> dets, const int* px_arry,
int px_arry_num, bool is_segmentation) {
int px_arry_num, bool is_segmentation, bool is_pose) {
auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");
const int netinfo_count = 5; // Assuming the first 5 elements are for netinfo as per existing code.
const int netinfo_count = 8; // Assuming the first 5 elements are for netinfo as per existing code.
const int total_count = netinfo_count + px_arry_num; // Total number of elements for netinfo and px_arry combined.

std::vector<int> combinedInfo(total_count);
// Fill in the first 5 elements as per existing netinfo.
combinedInfo[0] = kNumClass;
combinedInfo[1] = kInputW;
combinedInfo[2] = kInputH;
combinedInfo[3] = kMaxNumOutputBbox;
combinedInfo[4] = is_segmentation;
combinedInfo[1] = kNumberOfPoints;
combinedInfo[2] = kConfThreshKeypoints;
combinedInfo[3] = kInputW;
combinedInfo[4] = kInputH;
combinedInfo[5] = kMaxNumOutputBbox;
combinedInfo[6] = is_segmentation;
combinedInfo[7] = is_pose;

// Copy the contents of px_arry into the combinedInfo vector after the initial 5 elements.
std::copy(px_arry, px_arry + px_arry_num, combinedInfo.begin() + netinfo_count);
Expand Down
Loading
Loading