Skip to content

Commit

Permalink
Merge pull request opencv#24463 from dkurt:dnn_shared_nodes_fusion
Browse files Browse the repository at this point in the history
DNN graph fusion with shared nodes opencv#24463

### Pull Request Readiness Checklist

Currently, the nodes of a matched pattern are removed during the matching process, so if some of those nodes are also used by another, similar subgraph, that subgraph can no longer be matched.

required for opencv#24397

**Merge with extra**: opencv/opencv_extra#1115

A part of the [fcn-resnet101-11.onnx](https://github.com/onnx/models/blob/main/vision/object_detection_segmentation/fcn/model/fcn-resnet101-11.onnx) model that contains two Resize subgraphs with shared nodes:
![image](https://github.com/opencv/opencv/assets/25801568/611d89d9-12fb-4add-9218-13b10d2c086a)

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is an accuracy test, a performance test and test data in the opencv_extra repository, if applicable
      The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
dkurt authored and thewoz committed Jan 4, 2024
1 parent 06451a5 commit 090747f
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 6 deletions.
50 changes: 46 additions & 4 deletions modules/dnn/src/graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,7 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
inputsNames[i] = inpName;
}

// Remove matched nodes except the last one. Indices in ascending order are expected.
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds.back());
for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
net->removeNode(matchedNodesIds[i]);

// Modify the last node to be a fused one.
node->setType(fusedNodeOp);
Expand All @@ -191,17 +188,62 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
{
    // Applies every pattern to every node of the graph, then removes the nodes
    // that were fused away and are no longer referenced by any remaining node.
    //
    // Nodes are NOT removed while matching: a node may be shared between several
    // matched subgraphs, so removal is deferred until all patterns were applied.
    // Consequently the number of nodes stays constant during the matching phase.
    const int numNodes = net->getNumNodes();
    std::vector<int> matchedNodesIds, targetNodesIds;
    std::vector<int> nodesToRemove;
    for (size_t j = 0; j < patterns.size(); ++j)
    {
        for (int i = 0; i < numNodes; ++i)
        {
            if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
            {
                patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
                // Schedule matched nodes for removal, except the last one
                // (the last matched node is the one turned into the fused node).
                nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
            }
        }
    }

    if (nodesToRemove.empty())
        return;

    // Collect reference counts for every node.
    const int nodeCount = net->getNumNodes();
    std::vector<int> refcounts(nodeCount, 0);
    std::map<std::string, int> nodeIds;  // output name -> id of the producing node

    // Register node outputs.
    // Every usage of one of the node's outputs should be counted.
    for (int nodeId = 0; nodeId < nodeCount; ++nodeId) {
        for (int i = 0; i < net->getNumOutputs(nodeId); ++i) {
            std::string name = net->getOutputName(nodeId, i);
            nodeIds[name] = nodeId;
        }
    }

    for (int nodeId = 0; nodeId < nodeCount; ++nodeId) {
        // Increase counters for node's inputs.
        auto node = net->getNode(nodeId);
        for (int i = 0; i < node->getNumInputs(); ++i) {
            std::string inpName = node->getInputName(i);
            if (inpName.empty())
                continue;  // empty names were never registered as outputs
            auto producer = nodeIds.find(inpName);
            CV_Assert(producer != nodeIds.end());
            refcounts[producer->second] += 1;
        }
    }

    // Remove all fused nodes. Indices are processed in descending order so that
    // removing a node never shifts the index of a node that is still to be visited.
    std::sort(nodesToRemove.begin(), nodesToRemove.end(), [](int a, int b) { return a > b; });
    for (int nodeId : nodesToRemove) {
        // Skip nodes that are still in use (> 0) or already removed (-1).
        if (refcounts[nodeId] != 0)
            continue;

        // Decrease references to node's inputs and remove node itself.
        auto node = net->getNode(nodeId);
        for (int i = 0; i < node->getNumInputs(); ++i) {
            std::string inpName = node->getInputName(i);
            // Empty names were skipped while counting, so they must be skipped here
            // as well; otherwise operator[] would insert "" -> 0 into the map and
            // wrongly decrement the counter of node 0.
            if (inpName.empty())
                continue;
            auto producer = nodeIds.find(inpName);
            CV_Assert(producer != nodeIds.end());
            refcounts[producer->second] -= 1;
        }
        net->removeNode(nodeId);
        refcounts[nodeId] = -1;  // Same node cannot be removed twice
    }
}

}} // namespace cv::dnn
28 changes: 28 additions & 0 deletions modules/dnn/src/onnx/onnx_graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,33 @@ class ResizeSubgraph2 : public ExtractScalesSubgraph
}
};

// Pattern that is fused into a single "Upsample" node taking two inputs:
// the resized tensor and the tensor whose shape supplies the target size.
//
// Matched structure (registration order matters and is kept as-is):
//   sizeSource -> Shape -> Gather(Constant) -> Unsqueeze --+
//   sizeSource -> Shape -> Gather(Constant) -> Unsqueeze --+-> Concat -> Cast --+
//   data       -> Shape -> Slice(Constant x3) ---------------------------------+-> Concat
//   Resize(data, Constant, Constant, Concat)
class ResizeSubgraph3 : public Subgraph
{
public:
    ResizeSubgraph3() : Subgraph()
    {
        // Nodes registered with an empty op name.
        // NOTE(review): presumably these act as wildcard inputs - confirm in Subgraph.
        const int sizeSource = addNodeToMatch("");
        const int data = addNodeToMatch("");

        // Two parallel chains starting from the same tensor
        // (the "H"/"W" suffixes presumably stand for height/width).
        const int shapeH = addNodeToMatch("Shape", sizeSource);
        const int shapeW = addNodeToMatch("Shape", sizeSource);

        const int gatherIndexH = addNodeToMatch("Constant");
        const int gatherH = addNodeToMatch("Gather", shapeH, gatherIndexH);

        const int gatherIndexW = addNodeToMatch("Constant");
        const int gatherW = addNodeToMatch("Gather", shapeW, gatherIndexW);

        const int unsqueezeH = addNodeToMatch("Unsqueeze", gatherH);
        const int unsqueezeW = addNodeToMatch("Unsqueeze", gatherW);

        const int spatialSize = addNodeToMatch("Concat", unsqueezeH, unsqueezeW);
        const int spatialSizeCast = addNodeToMatch("Cast", spatialSize);

        // Sliced shape of the resized tensor, concatenated with the size computed above.
        const int dataShape = addNodeToMatch("Shape", data);
        const int sliceArg0 = addNodeToMatch("Constant");
        const int sliceArg1 = addNodeToMatch("Constant");
        const int sliceArg2 = addNodeToMatch("Constant");
        const int slicedShape = addNodeToMatch("Slice", dataShape, sliceArg0, sliceArg1, sliceArg2);
        const int targetShape = addNodeToMatch("Concat", slicedShape, spatialSizeCast);

        // The Resize node itself; it is registered last.
        const int resizeArg1 = addNodeToMatch("Constant");
        const int resizeArg2 = addNodeToMatch("Constant");
        addNodeToMatch("Resize", data, resizeArg1, resizeArg2, targetShape);

        setFusedNode("Upsample", data, sizeSource);
    }
};


class BatchNormalizationSubgraphBase : public Subgraph
{
public:
Expand Down Expand Up @@ -1207,6 +1234,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<UpsampleSubgraph>());
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<ResizeSubgraph3>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
Expand Down
10 changes: 8 additions & 2 deletions modules/dnn/test/test_onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class Test_ONNX_layers : public DNNTestLayer

void testONNXModels(const String& basename, const Extension ext = npy,
double l1 = 0, double lInf = 0, const bool useSoftmax = false,
bool checkNoFallbacks = true, int numInps = 1)
bool checkNoFallbacks = true, int numInps = 1,
bool testShapes = true)
{
String onnxmodel = _tf("models/" + basename + ".onnx", required);
std::vector<Mat> inps(numInps);
Expand All @@ -76,7 +77,8 @@ class Test_ONNX_layers : public DNNTestLayer
Net net = readNetFromONNX(onnxmodel);
ASSERT_FALSE(net.empty());

testInputShapes(net, inps);
if (testShapes)
testInputShapes(net, inps);

net.setPreferableBackend(backend);
net.setPreferableTarget(target);
Expand Down Expand Up @@ -248,6 +250,10 @@ TEST_P(Test_ONNX_layers, Gather_shared_indices) {
testONNXModels("gather_shared_indices", npy, 0, 0, false, false, 1);
}

// Runs the "two_resizes_with_shared_subgraphs" model, which takes three inputs,
// with the fallback check and the input-shape check both disabled.
TEST_P(Test_ONNX_layers, Two_resizes_with_shared_subgraphs) {
    const bool useSoftmax = false;
    const bool checkNoFallbacks = false;
    const int numInps = 3;
    const bool testShapes = false;
    testONNXModels("two_resizes_with_shared_subgraphs", npy, 0, 0,
                   useSoftmax, checkNoFallbacks, numInps, testShapes);
}

TEST_P(Test_ONNX_layers, Convolution3D)
{
if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)
Expand Down

0 comments on commit 090747f

Please sign in to comment.