Skip to content

Commit

Permalink
yolov8 detection p6
Browse files Browse the repository at this point in the history
  • Loading branch information
lindsayshuo committed Mar 31, 2024
1 parent c86d808 commit a9daf44
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 15 deletions.
3 changes: 3 additions & 0 deletions yolov8/include/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ nvinfer1::ITensor& input, int ch, int k, int s, int p, std::string lname);
nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int c1, int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights>& weightMap,
nvinfer1::ITensor& input, int c1, int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int c1, int c2, int k, std::string lname);

Expand Down
3 changes: 3 additions & 0 deletions yolov8/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
nvinfer1::IHostMemory* buildEngineYolov8Det(nvinfer1::IBuilder* builder,
nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw, int& max_channels);

nvinfer1::IHostMemory* buildEngineYolov8DetP6(nvinfer1::IBuilder* builder,
nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw, int& max_channels);

nvinfer1::IHostMemory* buildEngineYolov8Cls(nvinfer1::IBuilder* builder,
nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path, float& gd, float& gw);

Expand Down
31 changes: 31 additions & 0 deletions yolov8/src/block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,37 @@ nvinfer1::ITensor& input, int c1, int c2, int n, bool shortcut, float e, std::st
return conv2;
}

nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights>& weightMap,
nvinfer1::ITensor& input, int c1, int c2, int n, bool shortcut, float e, std::string lname) {
assert(network != nullptr);
int hidden_channels = static_cast<int>(c2 * e);

// cv1 branch
nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, input, 2 * hidden_channels, 1, 1, 0, lname + ".cv1");
nvinfer1::ITensor* cv1_out = conv1->getOutput(0);

// Split the output of cv1 into two tensors
nvinfer1::Dims dims = cv1_out->getDimensions();
nvinfer1::ISliceLayer* split1 = network->addSlice(*cv1_out, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{dims.d[0] / 2, dims.d[1], dims.d[2]}, nvinfer1::Dims3{1, 1, 1});
nvinfer1::ISliceLayer* split2 = network->addSlice(*cv1_out, nvinfer1::Dims3{dims.d[0] / 2, 0, 0}, nvinfer1::Dims3{dims.d[0] / 2, dims.d[1], dims.d[2]}, nvinfer1::Dims3{1, 1, 1});

// Create y1 bottleneck sequence
nvinfer1::ITensor* y1 = split1->getOutput(0);
for (int i = 0; i < n; ++i) {
auto* bottleneck_layer = bottleneck(network, weightMap, *y1, hidden_channels, hidden_channels, shortcut, 1.0, lname + ".m." + std::to_string(i));
y1 = bottleneck_layer->getOutput(0); // update 'y1' to be the output of the current bottleneck
}

// Concatenate y1 with the second split of cv1
nvinfer1::ITensor* concatInputs[2] = {y1, split2->getOutput(0)};
nvinfer1::IConcatenationLayer* cat = network->addConcatenation(concatInputs, 2);

// cv2 to produce the final output
nvinfer1::IElementWiseLayer* conv2 = convBnSiLU(network, weightMap, *cat->getOutput(0), c2, 1, 1, 0, lname + ".cv2");

return conv2;
}

nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int c1, int c2, int k, std::string lname){
int c_ = c1 / 2;
Expand Down
Loading

0 comments on commit a9daf44

Please sign in to comment.