Skip to content

Commit

Permalink
Enhancement/optimize edge all predicate (#5481)
Browse files Browse the repository at this point in the history
  • Loading branch information
czpmango committed Apr 7, 2023
1 parent 7c01128 commit 0659aa2
Show file tree
Hide file tree
Showing 12 changed files with 788 additions and 17 deletions.
20 changes: 20 additions & 0 deletions src/common/meta/SchemaManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,33 @@ class SchemaManager {

// Pure virtual; the result is wrapped in StatusOr, so callers must check ok() before use.
// NOTE(review): presumably the number of partitions of `space` (inferred from the name only)
// -- confirm against the implementations.
virtual StatusOr<int32_t> getPartsNum(GraphSpaceID space) = 0;

// Name-based convenience overload: resolves `tag` to its TagID in `space` via toTagID() and
// forwards to the ID-based getTagSchema(). Yields nullptr when the name cannot be resolved.
std::shared_ptr<const NebulaSchemaProvider> getTagSchema(GraphSpaceID space,
                                                         const std::string& tag,
                                                         SchemaVer ver = -1) {
  auto lookup = toTagID(space, tag);
  if (lookup.ok()) {
    return getTagSchema(space, lookup.value(), ver);
  }
  // Name resolution failed, so there is no schema to hand back.
  return nullptr;
}

// ID-based tag schema lookup; pure virtual. The name-based getTagSchema() overload of this
// class resolves the tag name with toTagID() and then forwards to this function.
// NOTE(review): `ver` defaults to -1, which presumably selects the latest version -- confirm
// against the implementations.
virtual std::shared_ptr<const NebulaSchemaProvider> getTagSchema(GraphSpaceID space,
TagID tag,
SchemaVer ver = -1) = 0;

// Returns a negative number when the schema does not exist.
// NOTE(review): presumably the newest schema version known for `tag` in `space` (inferred from
// the name only) -- confirm against the implementations.
virtual StatusOr<SchemaVer> getLatestTagSchemaVersion(GraphSpaceID space, TagID tag) = 0;

// Name-based convenience overload: resolves `edge` to its EdgeType in `space` via toEdgeType()
// and forwards to the type-based getEdgeSchema(). Yields nullptr when the name cannot be
// resolved.
std::shared_ptr<const NebulaSchemaProvider> getEdgeSchema(GraphSpaceID space,
                                                          const std::string& edge,
                                                          SchemaVer ver = -1) {
  auto lookup = toEdgeType(space, edge);
  if (lookup.ok()) {
    return getEdgeSchema(space, lookup.value(), ver);
  }
  // Name resolution failed, so there is no schema to hand back.
  return nullptr;
}

// Type-based edge schema lookup; pure virtual. The name-based getEdgeSchema() overload of this
// class resolves the edge name with toEdgeType() and then forwards to this function.
// NOTE(review): `ver` defaults to -1, which presumably selects the latest version -- confirm
// against the implementations.
virtual std::shared_ptr<const NebulaSchemaProvider> getEdgeSchema(GraphSpaceID space,
EdgeType edge,
SchemaVer ver = -1) = 0;
Expand Down
8 changes: 6 additions & 2 deletions src/graph/executor/query/TraverseExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ Status TraverseExecutor::buildRequestVids() {
auto vidType = SchemaUtil::propTypeToValueType(metaVidType.get_type());
for (; iter->valid(); iter->next()) {
const auto& vid = src->eval(ctx(iter));
DCHECK_EQ(vid.type(), vidType)
<< "Mismatched vid type: " << vid.type() << ", space vid type: " << vidType;
// FIXME(czp): Remove this DCHECK for now, we should check vid type at compile-time
if (vid.type() != vidType) {
return Status::Error("Vid type mismatched.");
}
// DCHECK_EQ(vid.type(), vidType)
// << "Mismatched vid type: " << vid.type() << ", space vid type: " << vidType;
if (vid.type() == vidType) {
vids_.emplace(vid);
}
Expand Down
4 changes: 3 additions & 1 deletion src/graph/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ nebula_add_library(
rule/PushFilterDownInnerJoinRule.cpp
rule/PushFilterDownNodeRule.cpp
rule/PushFilterDownScanVerticesRule.cpp
rule/PushFilterDownTraverseRule.cpp
rule/PushVFilterDownScanVerticesRule.cpp
rule/OptimizeEdgeIndexScanByFilterRule.cpp
rule/OptimizeTagIndexScanByFilterRule.cpp
Expand All @@ -57,8 +58,9 @@ nebula_add_library(
rule/PushLimitDownScanEdgesAppendVerticesRule.cpp
rule/PushTopNDownIndexScanRule.cpp
rule/PushLimitDownScanEdgesRule.cpp
rule/PushFilterDownTraverseRule.cpp
rule/PushFilterThroughAppendVerticesRule.cpp
rule/RemoveAppendVerticesBelowJoinRule.cpp
rule/EmbedEdgeAllPredIntoTraverseRule.cpp
rule/PushFilterThroughAppendVerticesRule.cpp
)

Expand Down
220 changes: 220 additions & 0 deletions src/graph/optimizer/rule/EmbedEdgeAllPredIntoTraverseRule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#include "graph/optimizer/rule/EmbedEdgeAllPredIntoTraverseRule.h"

#include "common/expression/AttributeExpression.h"
#include "common/expression/ConstantExpression.h"
#include "common/expression/Expression.h"
#include "common/expression/PredicateExpression.h"
#include "common/expression/PropertyExpression.h"
#include "common/expression/VariableExpression.h"
#include "graph/optimizer/OptContext.h"
#include "graph/optimizer/OptGroup.h"
#include "graph/planner/plan/PlanNode.h"
#include "graph/planner/plan/Query.h"
#include "graph/util/ExpressionUtils.h"
#include "graph/visitor/RewriteVisitor.h"

using nebula::Expression;
using nebula::graph::Filter;
using nebula::graph::PlanNode;
using nebula::graph::QueryContext;
using nebula::graph::Traverse;

namespace nebula {
namespace opt {

std::unique_ptr<OptRule> EmbedEdgeAllPredIntoTraverseRule::kInstance =
std::unique_ptr<EmbedEdgeAllPredIntoTraverseRule>(new EmbedEdgeAllPredIntoTraverseRule());

// Constructing the rule registers it with the query rule set, so creating the static
// kInstance is what makes the rule known to RuleSet::QueryRules().
EmbedEdgeAllPredIntoTraverseRule::EmbedEdgeAllPredIntoTraverseRule() {
RuleSet::QueryRules().addRule(this);
}

// Plan shape this rule applies to: a Filter node whose only listed dependency is a Traverse
// node. The pattern object is built once, on first use, and shared afterwards.
const Pattern& EmbedEdgeAllPredIntoTraverseRule::pattern() const {
  static Pattern filterOverTraverse = Pattern::create(
      PlanNode::Kind::kFilter, {Pattern::create(PlanNode::Kind::kTraverse)});
  return filterOverTraverse;
}

// No rule-specific precondition is checked here: the decision is delegated entirely to the
// base-class OptRule::match(). Whether the Filter actually contains a usable edge `all`
// predicate is decided later, inside transform().
bool EmbedEdgeAllPredIntoTraverseRule::match(OptContext* ctx, const MatchedResult& matched) const {
return OptRule::match(ctx, matched);
}

bool isEdgeAllPredicate(const Expression* e,
const std::string& edgeAlias,
std::string& innerEdgeVar) {
// reset the inner edge var name
innerEdgeVar = "";
if (e->kind() != Expression::Kind::kPredicate) {
return false;
}
auto* pe = static_cast<const PredicateExpression*>(e);
if (pe->name() != "all" || !pe->hasInnerVar()) {
return false;
}
auto var = pe->innerVar();
if (!pe->collection()->isPropertyExpr()) {
return false;
}
// check edge collection expression
if (static_cast<const PropertyExpression*>(pe->collection())->prop() != edgeAlias) {
return false;
}
auto ves = graph::ExpressionUtils::collectAll(pe->filter(), {Expression::Kind::kAttribute});
for (const auto& ve : ves) {
auto iv = static_cast<const AttributeExpression*>(ve)->left();

// check inner vars
if (iv->kind() != Expression::Kind::kVar) {
return false;
}
// only care inner edge vars
if (!static_cast<const VariableExpression*>(iv)->isInner()) {
// FIXME(czp): support parameter/variables in edge `all` predicate
return false;
}

// edge property in AttributeExpression must be Constant string
auto ep = static_cast<const AttributeExpression*>(ve)->right();
if (ep->kind() != Expression::Kind::kConstant) {
return false;
}
if (!static_cast<const ConstantExpression*>(ep)->value().isStr()) {
return false;
}
}

innerEdgeVar = var;
return true;
}

// Rewrites edge `all` predicates into single-hop edge predicates.
//
// Every sub-expression of `expr` accepted by isEdgeAllPredicate() for `edgeAlias` is replaced by
// a rewritten copy of the predicate's filter in which each attribute access on the predicate's
// inner variable (`<innerVar>.<prop>`) becomes the edge property expression `*.<prop>`.
//
// \param expr       expression containing the edge `all` predicate(s)
// \param edgeAlias  alias of the edge list produced by the Traverse
// \return the expression produced by RewriteVisitor::transform
Expression* rewriteEdgeAllPredicate(const Expression* expr, const std::string& edgeAlias) {
  // Hand-over slot between the outer matcher and the outer rewriter: the matcher records the
  // inner variable name of the predicate it accepted, the rewriter then reads it back.
  std::string innerEdgeVar;

  // Accepts attribute expressions whose left side is the variable named `innerEdgeVar`.
  auto refersToInnerEdgeVar = [&innerEdgeVar](const Expression* ae) -> bool {
    if (ae->kind() != Expression::Kind::kAttribute) {
      return false;
    }
    const auto* owner = static_cast<const AttributeExpression*>(ae)->left();
    return owner->kind() == Expression::Kind::kVar &&
           static_cast<const VariableExpression*>(owner)->var() == innerEdgeVar;
  };

  // Turns `<innerVar>.<prop>` into the edge property expression `*.<prop>`.
  auto toEdgeProperty = [](const Expression* ae) -> Expression* {
    DCHECK_EQ(ae->kind(), Expression::Kind::kAttribute);
    const auto* right = static_cast<const AttributeExpression*>(ae)->right();
    // The property-name side has already been validated by isEdgeAllPredicate().
    DCHECK_EQ(right->kind(), Expression::Kind::kConstant);
    const std::string& propName = static_cast<const ConstantExpression*>(right)->value().getStr();
    return EdgePropertyExpression::make(ae->getObjPool(), "*", propName);
  };

  auto isTargetPredicate = [&edgeAlias, &innerEdgeVar](const Expression* e) -> bool {
    return isEdgeAllPredicate(e, edgeAlias, innerEdgeVar);
  };

  // Drops the `all(...)` wrapper: only the predicate's filter survives, with its inner-variable
  // attribute accesses rewritten to edge properties.
  auto unwrapPredicate = [&refersToInnerEdgeVar, &toEdgeProperty](const Expression* e) -> Expression* {
    DCHECK_EQ(e->kind(), Expression::Kind::kPredicate);
    const auto* body = static_cast<const PredicateExpression*>(e)->filter();
    return graph::RewriteVisitor::transform(body, refersToInnerEdgeVar, toEdgeProperty);
  };

  return graph::RewriteVisitor::transform(
      expr, std::move(isTargetPredicate), std::move(unwrapPredicate));
}

// Moves edge `all` predicates from a Filter into the edge filter of the Traverse below it.
//
// Steps performed:
//  1. Split the Filter condition into a "picked" part (operands that contain an edge `all`
//     predicate over the Traverse's edge alias) and an "unpicked" remainder.
//  2. Rewrite the picked part with rewriteEdgeAllPredicate() and AND it with a clone of the
//     Traverse's existing edge filter, if there is one.
//  3. Clone the Traverse with the combined edge filter. If an unpicked remainder exists, a new
//     Filter holding it is placed above the new Traverse; otherwise the new Traverse takes the
//     old Filter's place in its group and adopts the old Filter's output variable.
//
// \param octx     optimizer context; its query context is used to create the new Filter
// \param matched  matched plan: `matched.node` is the Filter group node and
//                 `matched.dependencies[0].node` is the Traverse group node
// \return TransformResult::noTransform() when nothing was picked; otherwise a result with
//         `eraseAll` set and exactly one new group node
StatusOr<OptRule::TransformResult> EmbedEdgeAllPredIntoTraverseRule::transform(
OptContext* octx, const MatchedResult& matched) const {
auto* oldFilterGroupNode = matched.node;
auto* oldFilterGroup = oldFilterGroupNode->group();
auto* oldFilterNode = static_cast<graph::Filter*>(oldFilterGroupNode->node());
auto* condition = oldFilterNode->condition();
auto* oldTvGroupNode = matched.dependencies[0].node;
auto* oldTvNode = static_cast<graph::Traverse*>(oldTvGroupNode->node());
auto& edgeAlias = oldTvNode->edgeAlias();
auto qctx = octx->qctx();

// Pick all predicates containing edge `all` predicates under the AND semantics
auto picker = [&edgeAlias](const Expression* expr) -> bool {
// Once a UnaryNot has been visited, the finder rejects every expression visited after it.
bool neverPicked = false;
auto finder = [&neverPicked, &edgeAlias](const Expression* e) -> bool {
if (neverPicked) {
return false;
}
// A UnaryNot changes the semantics of an `all` predicate to `any`, which makes it impossible
// to scatter the edge `all` predicate into a single-hop edge predicate (double-not cases are
// not covered)
if (e->kind() == Expression::Kind::kUnaryNot) {
neverPicked = true;
return false;
}
// Not used, the picker only cares if there is an edge `all` predicate in the current operand
std::string innerVar;
return isEdgeAllPredicate(e, edgeAlias, innerVar);
};
graph::FindVisitor visitor(finder);
// accept() is a non-const member, hence the const_cast on the operand being inspected.
const_cast<Expression*>(expr)->accept(&visitor);
return !visitor.results().empty();
};
Expression* filterPicked = nullptr;
Expression* filterUnpicked = nullptr;
graph::ExpressionUtils::splitFilter(condition, picker, &filterPicked, &filterUnpicked);

// Nothing in the condition qualifies: leave the plan untouched.
if (!filterPicked) {
return TransformResult::noTransform();
}

// reconnect the existing edge filters
auto* edgeFilter = rewriteEdgeAllPredicate(filterPicked, edgeAlias);
auto* oldEdgeFilter = oldTvNode->eFilter();
Expression* newEdgeFilter =
oldEdgeFilter ? LogicalExpression::makeAnd(
oldEdgeFilter->getObjPool(), edgeFilter, oldEdgeFilter->clone())
: edgeFilter;

// produce new Traverse node
auto* newTvNode = static_cast<graph::Traverse*>(oldTvNode->clone());
newTvNode->setEdgeFilter(newEdgeFilter);
newTvNode->setInputVar(oldTvNode->inputVar());
newTvNode->setColNames(oldTvNode->outputVarPtr()->colNames);

// connect the optimized plan
TransformResult result;
result.eraseAll = true;
if (filterUnpicked) {
// assemble the new Filter node with the old Filter group
auto* newAboveFilterNode = graph::Filter::make(qctx, newTvNode, filterUnpicked);
newAboveFilterNode->setOutputVar(oldFilterNode->outputVar());
newAboveFilterNode->setColNames(oldFilterNode->colNames());
auto newAboveFilterGroupNode = OptGroupNode::create(octx, newAboveFilterNode, oldFilterGroup);
// assemble the new Traverse group below Filter
auto newTvGroup = OptGroup::create(octx);
auto newTvGroupNode = newTvGroup->makeGroupNode(newTvNode);
newTvGroupNode->setDeps(oldTvGroupNode->dependencies());
newAboveFilterGroupNode->setDeps({newTvGroup});
newAboveFilterNode->setInputVar(newTvNode->outputVar());
result.newGroupNodes.emplace_back(newAboveFilterGroupNode);
} else {
// the whole condition was embedded: the new Traverse node takes the old Filter's place in
// the old Filter group and takes over the Filter's output variable
auto newTvGroupNode = OptGroupNode::create(octx, newTvNode, oldFilterGroup);
newTvNode->setOutputVar(oldFilterNode->outputVar());
newTvGroupNode->setDeps(oldTvGroupNode->dependencies());
result.newGroupNodes.emplace_back(newTvGroupNode);
}

return result;
}

// Returns the rule's name as a string; the value is the class name itself.
std::string EmbedEdgeAllPredIntoTraverseRule::toString() const {
return "EmbedEdgeAllPredIntoTraverseRule";
}

} // namespace opt
} // namespace nebula
39 changes: 39 additions & 0 deletions src/graph/optimizer/rule/EmbedEdgeAllPredIntoTraverseRule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2023 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#pragma once

#include "graph/optimizer/OptRule.h"

namespace nebula {
namespace opt {

/*
 * Embeds edge `all` predicates of a Filter into the edge filter of the Traverse below it.
 *
 * Before:
 *   Filter(all(i in e where i.likeness > 78))
 *            |
 *         Traverse
 *
 * After :
 *   Traverse(eFilter_: *.likeness > 78)
 */
class EmbedEdgeAllPredIntoTraverseRule final : public OptRule {
public:
// Plan pattern this rule is applied to.
const Pattern &pattern() const override;

// Decides whether the rule applies to the matched plan fragment.
bool match(OptContext *ctx, const MatchedResult &matched) const override;

// Produces the rewritten plan fragment for a matched Filter/Traverse pair.
StatusOr<TransformResult> transform(OptContext *ctx, const MatchedResult &matched) const override;

// Name of the rule.
std::string toString() const override;

private:
// Private: the only instance is the static kInstance below.
EmbedEdgeAllPredIntoTraverseRule();

// The single instance of this rule.
static std::unique_ptr<OptRule> kInstance;
};

} // namespace opt
} // namespace nebula
2 changes: 1 addition & 1 deletion src/graph/optimizer/rule/GeoPredicateIndexScanBaseRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ StatusOr<TransformResult> GeoPredicateIndexScanBaseRule::transform(
}
TransformResult result;
result.newGroupNodes.emplace_back(optScanNode);
result.eraseCurr = true;
result.eraseAll = true;
return result;
}

Expand Down
40 changes: 40 additions & 0 deletions src/graph/util/ExpressionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,11 @@ void ExpressionUtils::splitFilter(const Expression *expr,

std::vector<Expression *> &operands = logicExpr->operands();
for (auto &operand : operands) {
// TODO(czp): Sink all NOTs to second layer [[Refactor]]
// TODO(czp): If any NOT is found, don't pick this operand for now
if (ExpressionUtils::findAny(operand, {Expression::Kind::kUnaryNot})) {
filterUnpickedPtr->addOperand(operand->clone());
}
if (picker(operand)) {
filterPickedPtr->addOperand(operand->clone());
} else {
Expand Down Expand Up @@ -1668,5 +1673,40 @@ bool ExpressionUtils::isOneStepEdgeProp(const std::string &edgeAlias, const Expr
return graph::RewriteVisitor::transform(expr, matcher, rewriter);
}

// Rewrites label-tag-property expressions anchored on `node` (e.g. $-.v.player.name) into plain
// tag property expressions (e.g. player.name), a form that is friendlier to push down.
// \param pool object pool that takes ownership of the newly allocated expressions
// \param node the name of the node, i.e. v in pattern (v)
// \param expr the filter expression to rewrite
/*static*/ Expression *ExpressionUtils::rewriteVertexPropertyFilter(ObjectPool *pool,
                                                                    const std::string &node,
                                                                    Expression *expr) {
  // Selects label-tag-property expressions whose label is an input/variable property named
  // `node`.
  graph::RewriteVisitor::Matcher isNodeTagProp = [&node](const Expression *e) -> bool {
    if (e->kind() != Expression::Kind::kLabelTagProperty) {
      return false;
    }
    auto *labelExpr = static_cast<const LabelTagPropertyExpression *>(e)->label();
    DCHECK(labelExpr->kind() == Expression::Kind::kInputProperty ||
           labelExpr->kind() == Expression::Kind::kVarProperty);
    switch (labelExpr->kind()) {
      case Expression::Kind::kInputProperty:
      case Expression::Kind::kVarProperty:
        return static_cast<const PropertyExpression *>(labelExpr)->prop() == node;
      default:
        // Release builds skip the DCHECK above; any other label kind is simply not rewritten.
        return false;
    }
  };

  // Keeps only the tag name and the property name of the matched expression.
  graph::RewriteVisitor::Rewriter toTagProp = [pool](const Expression *e) -> Expression * {
    DCHECK_EQ(e->kind(), Expression::Kind::kLabelTagProperty);
    const auto *labelTagProp = static_cast<const LabelTagPropertyExpression *>(e);
    return TagPropertyExpression::make(pool, labelTagProp->sym(), labelTagProp->prop());
  };

  return graph::RewriteVisitor::transform(expr, isNodeTagProp, toTagProp);
}

} // namespace graph
} // namespace nebula
Loading

0 comments on commit 0659aa2

Please sign in to comment.