Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push filter down cross join #5473

Merged
merged 3 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/graph/executor/algo/CartesianProductExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "graph/executor/algo/CartesianProductExecutor.h"

#include "graph/planner/plan/Algo.h"
#include "graph/planner/plan/Query.h"

namespace nebula {
namespace graph {
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ nebula_add_library(
OptGroup.cpp
OptRule.cpp
OptContext.cpp
rule/PushFilterDownCrossJoinRule.cpp
rule/PushFilterDownGetNbrsRule.cpp
rule/RemoveNoopProjectRule.cpp
rule/CombineFilterRule.cpp
Expand Down
139 changes: 139 additions & 0 deletions src/graph/optimizer/rule/PushFilterDownCrossJoinRule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#include "graph/optimizer/rule/PushFilterDownCrossJoinRule.h"

#include "graph/optimizer/OptContext.h"
#include "graph/optimizer/OptGroup.h"
#include "graph/planner/plan/PlanNode.h"
#include "graph/planner/plan/Query.h"
#include "graph/util/ExpressionUtils.h"

using nebula::graph::CrossJoin;
using nebula::graph::ExpressionUtils;
using nebula::graph::Filter;
using nebula::graph::PlanNode;
using nebula::graph::QueryContext;

namespace nebula {
namespace opt {

std::unique_ptr<OptRule> PushFilterDownCrossJoinRule::kInstance =
std::unique_ptr<PushFilterDownCrossJoinRule>(new PushFilterDownCrossJoinRule());

// Registers this rule in the query rule set as soon as it is constructed.
PushFilterDownCrossJoinRule::PushFilterDownCrossJoinRule() {
  auto&& queryRules = RuleSet::QueryRules();
  queryRules.addRule(this);
}

// Returns the plan shape this rule matches:
//
//   Filter
//     └─ CrossJoin
//          ├─ <kUnknown>   (left input)
//          └─ <kUnknown>   (right input)
//
// The pattern is built once, on first use, and shared by all later calls.
const Pattern& PushFilterDownCrossJoinRule::pattern() const {
  static Pattern kPattern = [] {
    auto leftInput = Pattern::create(PlanNode::Kind::kUnknown);
    auto rightInput = Pattern::create(PlanNode::Kind::kUnknown);
    auto crossJoin = Pattern::create(PlanNode::Kind::kCrossJoin, {leftInput, rightInput});
    return Pattern::create(PlanNode::Kind::kFilter, {crossJoin});
  }();
  return kPattern;
}

// Rewrites a matched `Filter -> CrossJoin(left, right)` sub-plan so that as much of the
// filter condition as possible is evaluated below the join.
//
// Steps:
//   1. Offer the whole condition to the left child, then offer whatever the left child did
//      not take to the right child (see pushFilterDownChild()).
//   2. If neither child accepted anything, report "no transform".
//   3. Otherwise clone the CrossJoin on top of the (possibly new) child groups. If some
//      part of the condition is still left over, keep a Filter with that remainder above
//      the new CrossJoin; if nothing is left over, the CrossJoin replaces the Filter.
//
// @param octx     optimizer context used to create new groups / group nodes.
// @param matched  match result rooted at the Filter group node.
// @return a TransformResult with `eraseAll` set and one new group node, or
//         TransformResult::noTransform() when nothing could be pushed down.
StatusOr<OptRule::TransformResult> PushFilterDownCrossJoinRule::transform(
OptContext* octx, const MatchedResult& matched) const {
// Root of the match: the Filter.
auto* filterGroupNode = matched.node;
auto* oldFilterNode = filterGroupNode->node();
DCHECK_EQ(oldFilterNode->kind(), PlanNode::Kind::kFilter);

// Its only dependency in the pattern: the CrossJoin.
auto* crossJoinNode = matched.planNode({0, 0});
DCHECK_EQ(crossJoinNode->kind(), PlanNode::Kind::kCrossJoin);
auto* oldCrossJoinNode = static_cast<const CrossJoin*>(crossJoinNode);

const auto* condition = static_cast<Filter*>(oldFilterNode)->condition();
DCHECK(condition);

// The two inputs of the CrossJoin.
const auto& leftResult = matched.result({0, 0, 0});
const auto& rightResult = matched.result({0, 0, 1});

// Left child gets first pick of the condition; the right child is only offered what the
// left child left behind. `rightFilterUnpicked` therefore ends up holding the part that
// neither child took (or nullptr).
Expression *leftFilterUnpicked = nullptr, *rightFilterUnpicked = nullptr;
OptGroup* leftGroup = pushFilterDownChild(octx, leftResult, condition, &leftFilterUnpicked);
OptGroup* rightGroup =
pushFilterDownChild(octx, rightResult, leftFilterUnpicked, &rightFilterUnpicked);

// A null group means nothing was pushed into that child.
if (!leftGroup && !rightGroup) {
return TransformResult::noTransform();
}

// For a child that received no filter, keep depending on its original group.
leftGroup = leftGroup ? leftGroup : const_cast<OptGroup*>(leftResult.node->group());
rightGroup = rightGroup ? rightGroup : const_cast<OptGroup*>(rightResult.node->group());

// produce new CrossJoin node
auto* newCrossJoinNode = static_cast<CrossJoin*>(oldCrossJoinNode->clone());
// With a leftover filter the join needs a group of its own (the Filter stays on top in the
// original group); without one the join itself is placed in the Filter's group.
auto newJoinGroup = rightFilterUnpicked ? OptGroup::create(octx) : filterGroupNode->group();
// TODO(yee): it's too tricky
auto newGroupNode = rightFilterUnpicked
? const_cast<OptGroup*>(newJoinGroup)->makeGroupNode(newCrossJoinNode)
: OptGroupNode::create(octx, newCrossJoinNode, newJoinGroup);
newGroupNode->dependsOn(leftGroup);
newGroupNode->dependsOn(rightGroup);
newCrossJoinNode->setLeftVar(leftGroup->outputVar());
newCrossJoinNode->setRightVar(rightGroup->outputVar());

if (rightFilterUnpicked) {
// Part of the condition could not be pushed down: re-create a Filter above the new join
// that takes over the old Filter's output variable and column names.
auto newFilterNode = Filter::make(octx->qctx(), nullptr, rightFilterUnpicked);
newFilterNode->setOutputVar(oldFilterNode->outputVar());
newFilterNode->setColNames(oldFilterNode->colNames());
newFilterNode->setInputVar(newCrossJoinNode->outputVar());
newGroupNode = OptGroupNode::create(octx, newFilterNode, filterGroupNode->group());
newGroupNode->dependsOn(const_cast<OptGroup*>(newJoinGroup));
} else {
// The whole condition was pushed down: the join now writes to the variable the old
// Filter used to write to.
newCrossJoinNode->setOutputVar(oldFilterNode->outputVar());
newCrossJoinNode->setColNames(oldCrossJoinNode->colNames());
}

TransformResult result;
result.eraseAll = true;
result.newGroupNodes.emplace_back(newGroupNode);
return result;
}

// Tries to push (part of) `condition` below one input of the CrossJoin.
//
// On success the child sub-plan becomes `Filter(picked) -> clone(child)`, each in a freshly
// created group, and the group holding the new Filter is returned.
//
// @param octx            optimizer context used to create groups and plan nodes.
// @param child           match result of the CrossJoin input to push into (left or right).
// @param condition       filter expression to split; may be nullptr.
// @param unpickedFilter  out-parameter handed to ExpressionUtils::splitFilter() to receive
//                        the part of `condition` that was not picked. It is not written at
//                        all when `condition` is nullptr.
// @return the new group rooted at the pushed-down Filter, or nullptr when `condition` is
//         nullptr or no part of it was picked for this child.
OptGroup* PushFilterDownCrossJoinRule::pushFilterDownChild(OptContext* octx,
const MatchedResult& child,
const Expression* condition,
Expression** unpickedFilter) {
if (!condition) return nullptr;

const auto* childPlanNode = DCHECK_NOTNULL(child.node->node());
const auto& colNames = childPlanNode->colNames();

// Split `condition` using this child's column names: the picker delegates to
// ExpressionUtils::checkColName(), which presumably accepts an expression only when the
// columns it references belong to `colNames` (TODO confirm against ExpressionUtils).
auto picker = [&colNames](const Expression* e) -> bool {
return ExpressionUtils::checkColName(colNames, e);
};

Expression* filterPicked = nullptr;
ExpressionUtils::splitFilter(condition, picker, &filterPicked, unpickedFilter);
if (!filterPicked) return nullptr;

// Clone the child into its own group. The clone gets a new output variable (checked
// below) but keeps the original input variable, column names and dependencies.
auto* newChildPlanNode = childPlanNode->clone();
DCHECK_NE(childPlanNode->outputVar(), newChildPlanNode->outputVar());
newChildPlanNode->setInputVar(childPlanNode->inputVar());
newChildPlanNode->setColNames(childPlanNode->colNames());
auto* newChildGroup = OptGroup::create(octx);
auto* newChildGroupNode = newChildGroup->makeGroupNode(newChildPlanNode);
for (auto* g : child.node->dependencies()) {
newChildGroupNode->dependsOn(g);
}

// The new Filter reads the clone's output and writes to the ORIGINAL child's output
// variable, with the same column names as the original child.
auto* newFilterNode = Filter::make(octx->qctx(), nullptr, filterPicked);
newFilterNode->setOutputVar(childPlanNode->outputVar());
newFilterNode->setColNames(colNames);
newFilterNode->setInputVar(newChildPlanNode->outputVar());
auto* group = OptGroup::create(octx);
group->makeGroupNode(newFilterNode)->dependsOn(newChildGroup);
return group;
}

// Returns the name of this rule.
std::string PushFilterDownCrossJoinRule::toString() const {
  static constexpr char kRuleName[] = "PushFilterDownCrossJoinRule";
  return kRuleName;
}

} // namespace opt
} // namespace nebula
37 changes: 37 additions & 0 deletions src/graph/optimizer/rule/PushFilterDownCrossJoinRule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#ifndef GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_
#define GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_

#include "graph/optimizer/OptRule.h"

namespace nebula {
namespace opt {

// Push down the filter items into the child sub-plans of [[CrossJoin]].
//
// Matches a Filter whose input is a CrossJoin and moves the parts of the filter condition
// that can be split off for one input of the join into that input's sub-plan.
//
// The constructor is private; the only instance is the static `kInstance` member.
class PushFilterDownCrossJoinRule final : public OptRule {
public:
// Plan pattern matched by this rule.
const Pattern &pattern() const override;

// Rewrites the matched sub-plan rooted at `matched`; new groups and nodes are created
// through `octx`.
StatusOr<OptRule::TransformResult> transform(OptContext *octx,
const MatchedResult &matched) const override;

// Name of this rule.
std::string toString() const override;

private:
PushFilterDownCrossJoinRule();
// Tries to push the part of `condition` that applies to `child` below the join.
// Returns the new group for that child, or nullptr if nothing was pushed; the part of
// the condition that was not pushed is reported through `unpickedFilter`.
static OptGroup *pushFilterDownChild(OptContext *octx,
const MatchedResult &child,
const Expression *condition,
Expression **unpickedFilter);

// The single instance of this rule.
static std::unique_ptr<OptRule> kInstance;
};

} // namespace opt
} // namespace nebula

#endif // GRAPH_OPTIMIZER_RULE_PUSHFILTERDOWNCROSSJOINRULE_H_
3 changes: 1 addition & 2 deletions src/graph/planner/match/MatchSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ Expression* MatchSolver::makeIndexFilter(const std::string& label,

auto* root = relationals[0];
for (auto i = 1u; i < relationals.size(); i++) {
auto* left = root;
root = LogicalExpression::makeAnd(qctx->objPool(), left, relationals[i]);
root = LogicalExpression::makeAnd(qctx->objPool(), root, relationals[i]);
}

return root;
Expand Down
2 changes: 1 addition & 1 deletion src/graph/planner/match/StartVidFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using StartVidFinderInstantiateFunc = std::function<std::unique_ptr<StartVidFind
// 3. PropIndexSeek finds if a plan could traverse from some vids that could be
// read from the property indices.
// MATCH(n:Tag{prop:value}) RETURN n
// MATCH(n:Tag) WHERE n.prop = value RETURN n
// MATCH(n:Tag) WHERE n.Tag.prop = value RETURN n
//
// 4. LabelIndexSeek finds if a plan could traverse from some vids that could be
// read from the label indices.
Expand Down
28 changes: 0 additions & 28 deletions src/graph/planner/plan/Algo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,34 +139,6 @@ std::vector<std::string> CartesianProduct::inputVars() const {
return varNames;
}

// Returns the description produced by BinaryInputNode::explain(); CrossJoin adds no
// fields of its own.
std::unique_ptr<PlanNodeDescription> CrossJoin::explain() const {
  return BinaryInputNode::explain();
}

// Creates a new CrossJoin through the clone-only factory and copies this node's members
// into it.
PlanNode* CrossJoin::clone() const {
  auto* node = make(qctx_);
  node->cloneMembers(*this);
  return node;
}

// Copies the members of `r`; CrossJoin has no state beyond BinaryInputNode's.
void CrossJoin::cloneMembers(const CrossJoin& r) {
  BinaryInputNode::cloneMembers(r);
}

// Builds a CrossJoin over `left` and `right`. The output column names are the left
// input's column names followed by the right input's column names.
CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right)
    : BinaryInputNode(qctx, Kind::kCrossJoin, left, right) {
  // Only the left names need to be an owned copy (they are appended to); the right names
  // are merely read, so bind them by reference instead of copying the vector.
  auto joinedColNames = left->colNames();
  const auto& rightColNames = right->colNames();
  joinedColNames.insert(joinedColNames.end(), rightColNames.begin(), rightColNames.end());
  setColNames(joinedColNames);
}

// Dispatches to the visitor's overload for CrossJoin.
void CrossJoin::accept(PlanNodeVisitor* visitor) {
  visitor->visit(this);
}

// Clone-only constructor: no inputs are attached and no column names are set here.
CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {}

std::unique_ptr<PlanNodeDescription> Subgraph::explain() const {
auto desc = SingleInputNode::explain();
addDescription("src", src_ ? src_->toString() : "", desc.get());
Expand Down
26 changes: 0 additions & 26 deletions src/graph/planner/plan/Algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,32 +437,6 @@ class Subgraph final : public SingleInputNode {
std::unique_ptr<std::vector<EdgeProp>> edgeProps_;
};

// Plan node with two inputs (`left` and `right`) that are cross-joined.
class CrossJoin final : public BinaryInputNode {
 public:
  // Creates a CrossJoin over `left` and `right` through the query context's object pool.
  static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) {
    auto* pool = qctx->objPool();
    return pool->makeAndAdd<CrossJoin>(qctx, left, right);
  }

  std::unique_ptr<PlanNodeDescription> explain() const override;

  PlanNode* clone() const override;

  void accept(PlanNodeVisitor* visitor) override;

 private:
  // ObjectPool needs access to the private constructors.
  friend ObjectPool;

  // Clone-only factory: creates a CrossJoin with no inputs attached.
  static CrossJoin* make(QueryContext* qctx) {
    auto* pool = qctx->objPool();
    return pool->makeAndAdd<CrossJoin>(qctx);
  }

  void cloneMembers(const CrossJoin& r);

  CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right);
  // Clone-only constructor.
  explicit CrossJoin(QueryContext* qctx);
};
} // namespace graph
} // namespace nebula
#endif // GRAPH_PLANNER_PLAN_ALGO_H_
28 changes: 28 additions & 0 deletions src/graph/planner/plan/Query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,34 @@ void HashInnerJoin::cloneMembers(const HashInnerJoin& l) {
HashJoin::cloneMembers(l);
}

// Returns the description produced by BinaryInputNode::explain(); CrossJoin adds no
// fields of its own.
std::unique_ptr<PlanNodeDescription> CrossJoin::explain() const {
  return BinaryInputNode::explain();
}

// Creates a new CrossJoin through the clone-only factory and copies this node's members
// into it.
PlanNode* CrossJoin::clone() const {
  auto* node = make(qctx_);
  node->cloneMembers(*this);
  return node;
}

// Copies the members of `r`; CrossJoin has no state beyond BinaryInputNode's.
void CrossJoin::cloneMembers(const CrossJoin& r) {
  BinaryInputNode::cloneMembers(r);
}

// Builds a CrossJoin over `left` and `right`. The output column names are the left
// input's column names followed by the right input's column names.
CrossJoin::CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right)
    : BinaryInputNode(qctx, Kind::kCrossJoin, left, right) {
  // Only the left names need to be an owned copy (they are appended to); the right names
  // are merely read, so bind them by reference instead of copying the vector.
  auto joinedColNames = left->colNames();
  const auto& rightColNames = right->colNames();
  joinedColNames.insert(joinedColNames.end(), rightColNames.begin(), rightColNames.end());
  setColNames(joinedColNames);
}

// Dispatches to the visitor's overload for CrossJoin.
void CrossJoin::accept(PlanNodeVisitor* visitor) {
  visitor->visit(this);
}

// Clone-only constructor: no inputs are attached and no column names are set here.
CrossJoin::CrossJoin(QueryContext* qctx) : BinaryInputNode(qctx, Kind::kCrossJoin) {}

std::unique_ptr<PlanNodeDescription> RollUpApply::explain() const {
auto desc = BinaryInputNode::explain();
addDescription("compareCols", folly::toJson(util::toJson(compareCols_)), desc.get());
Expand Down
27 changes: 27 additions & 0 deletions src/graph/planner/plan/Query.h
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,33 @@ class HashInnerJoin final : public HashJoin {
void cloneMembers(const HashInnerJoin&);
};

// Plan node with two inputs (`left` and `right`) that are cross-joined.
class CrossJoin final : public BinaryInputNode {
 public:
  // Creates a CrossJoin over `left` and `right` through the query context's object pool.
  static CrossJoin* make(QueryContext* qctx, PlanNode* left, PlanNode* right) {
    auto* pool = qctx->objPool();
    return pool->makeAndAdd<CrossJoin>(qctx, left, right);
  }

  std::unique_ptr<PlanNodeDescription> explain() const override;

  PlanNode* clone() const override;

  void accept(PlanNodeVisitor* visitor) override;

 private:
  // ObjectPool needs access to the private constructors.
  friend ObjectPool;

  // Clone-only factory: creates a CrossJoin with no inputs attached.
  static CrossJoin* make(QueryContext* qctx) {
    auto* pool = qctx->objPool();
    return pool->makeAndAdd<CrossJoin>(qctx);
  }

  void cloneMembers(const CrossJoin& r);

  CrossJoin(QueryContext* qctx, PlanNode* left, PlanNode* right);
  // Clone-only constructor.
  explicit CrossJoin(QueryContext* qctx);
};

// Roll Up Apply two results from two inputs.
class RollUpApply : public BinaryInputNode {
public:
Expand Down
32 changes: 32 additions & 0 deletions tests/tck/features/optimizer/PushFilterDownCrossJoinRule.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023 vesoft inc. All rights reserved.
#
# This source code is licensed under Apache 2.0 License.
Feature: Push Filter down CrossJoin rule

Background:
Given a graph with space named "nba"

Scenario: push filter down CrossJoin
When profiling query:
"""
with ['Tim Duncan', 'Tony Parker'] as id_list
match (v1:player)-[e]-(v2:player)
where id(v1) in ['Tim Duncan', 'Tony Parker'] AND id(v2) in ['Tim Duncan', 'Tony Parker']
return count(e)
"""
Then the result should be, in any order:
| count(e) |
| 8 |
And the execution plan should be:
| id | name | dependencies | operator info |
| 11 | Aggregate | 14 | |
| 14 | CrossJoin | 1,16 | |
| 1 | Project | 2 | |
| 2 | Start | | |
| 16 | Project | 15 | |
| 15 | Filter | 18 | {"condition": "((id($-.v1) IN [\"Tim Duncan\",\"Tony Parker\"]) AND (id($-.v2) IN [\"Tim Duncan\",\"Tony Parker\"]))"} |
| 18 | AppendVertices | 17 | |
| 17 | Traverse | 4 | |
| 4 | Dedup | 3 | |
| 3 | PassThrough | 5 | |
| 5 | Start | | |