Skip to content

Commit

Permalink
Merge pull request opencv#24483 from dkurt:dnn_fusion_commutative_ops
Browse files Browse the repository at this point in the history
Commutative rules for DNN subgraphs fusion opencv#24483

### Pull Request Readiness Checklist

related: opencv#24463 (comment)

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
dkurt authored and thewoz committed Jan 4, 2024
1 parent cf81fb8 commit 7a98c8a
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 228 deletions.
68 changes: 40 additions & 28 deletions modules/dnn/src/graph_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
}

bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds)
std::vector<int>& matchedNodesIds)
{
matchedNodesIds.clear();
targetNodesIds.clear();

std::queue<int> nodesToMatch;
std::queue<int> targetNodes;
std::vector<std::pair<int, int> > matchings;
matchings.reserve(nodes.size());
nodesToMatch.push(nodeId);
targetNodes.push(nodes.size() - 1);
while (!nodesToMatch.empty())
Expand All @@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
nodesToMatch.pop();
targetNodes.pop();

if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
matchedNodesIds.end())
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
matchings.end())
continue;

// Empty placeholder matches with any input type
if (nodes[targetNodeId].empty()) {
matchings.push_back({targetNodeId, nodeToMatch});
continue;
}

const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() != nodes[targetNodeId])
return false;
continue;

std::vector<int>& inputNodes = inputs[targetNodeId];
if (inputNodes.size() != node->getNumInputs())
return false;
continue;

bool isCommutative = net->isCommutativeOp(node->getType());

for (int j = 0; j < inputNodes.size(); ++j)
{
if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type.
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
if (node->getInputName(j).empty())
continue;
nodeId = getInputNodeId(net, node, j);
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
if (isCommutative)
{
for (int i = 0; i < inputNodes.size(); ++i)
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[i]);
}
}
else
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[j]);
}
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
return false;
}
matchedNodesIds.push_back(nodeToMatch);
targetNodesIds.push_back(targetNodeId);
matchings.push_back({targetNodeId, nodeToMatch});
}
if (matchings.size() != nodes.size())
return false;

const int n = matchedNodesIds.size();
std::vector<std::pair<int, int> > elements(n);
for (int i = 0; i < n; ++i)
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
std::sort(elements.begin(), elements.end());
for (int i = 0; i < n; ++i)
// Sort matched by pattern nodes order.
std::sort(matchings.begin(), matchings.end());
matchedNodesIds.resize(matchings.size());
for (int i = 0; i < matchings.size(); ++i)
{
matchedNodesIds[i] = elements[i].first;
targetNodesIds[i] = elements[i].second;
matchedNodesIds[i] = matchings[i].second;
}
return true;
}

void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
const std::vector<int>& targetNodesIds)
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
{
// Extract names of input nodes.
std::vector<std::string> inputsNames(fusedNodeInputs.size());
Expand All @@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
{
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
std::vector<int>& inpIndices = inputs[j];

CV_Assert(node->getNumInputs() == inpIndices.size());
CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size());
for (int k = 0; k < inpIndices.size(); ++k)
{
if (inpIndices[k] == fusedNodeInputs[i])
Expand Down Expand Up @@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
const std::vector<Ptr<Subgraph> >& patterns)
{
int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds;
std::vector<int> matchedNodesIds;
std::vector<int> nodesToRemove;
for (int j = 0; j < patterns.size(); ++j)
{
for (int i = 0; i < numNodes; ++i)
{
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
if (patterns[j]->match(net, i, matchedNodesIds))
{
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
patterns[j]->replace(net, matchedNodesIds);
// Remove matched nodes except the last one.
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
}
Expand Down
8 changes: 4 additions & 4 deletions modules/dnn/src/graph_simplifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class ImportGraphWrapper
virtual std::string getOutputName(int nodeId, int outId) const = 0;

virtual void removeNode(int idx) = 0;

virtual bool isCommutativeOp(const std::string& type) const = 0;
};

class Subgraph // Interface to match and replace subgraphs.
Expand Down Expand Up @@ -75,12 +77,10 @@ class Subgraph // Interface to match and replace subgraphs.
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds);
std::vector<int>& matchedNodesIds);

// Fuse matched subgraph.
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
const std::vector<int>& targetNodesIds);
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds);

virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode,
Expand Down

0 comments on commit 7a98c8a

Please sign in to comment.